diff options
author | Shauren <shauren.trinity@gmail.com> | 2016-03-10 23:26:26 +0100 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2016-03-10 23:33:56 +0100 |
commit | f123c396591ffb50fa7e02365235740df618f579 (patch) | |
tree | 20eeb1e0c76814267e204b2314cf42db21d5aab4 /src | |
parent | 52bb6485417494e58361a75fcf73cf5f0a1d09c5 (diff) |
Core/Networking: Added new AsyncRead method to Socket class allowing to pass a custom completion handler and refactor world socket initialization string handling
Diffstat (limited to 'src')
-rw-r--r-- | src/server/game/Server/WorldSocket.cpp | 272 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.h | 3 | ||||
-rw-r--r-- | src/server/shared/Networking/Socket.h | 11 |
3 files changed, 164 insertions, 122 deletions
diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index acec25a7363..843b94bf628 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -57,18 +57,18 @@ std::string const WorldSocket::ServerConnectionInitialize("WORLD OF WARCRAFT CON std::string const WorldSocket::ClientConnectionInitialize("WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER"); uint32 const WorldSocket::MinSizeForCompression = 0x400; -uint32 const SizeOfClientHeader[2][2] = +uint32 const SizeOfClientHeader[2] = { - { 2, 0 }, - { 6, 4 } + 6, 4 }; uint32 const SizeOfServerHeader[2] = { sizeof(uint16) + sizeof(uint32), sizeof(uint32) }; + WorldSocket::WorldSocket(tcp::socket&& socket) : Socket(std::move(socket)), _type(CONNECTION_TYPE_REALM), _authSeed(rand32()), _OverSpeedPings(0), - _worldSession(nullptr), _authed(false), _compressionStream(nullptr), _initialized(false) + _worldSession(nullptr), _authed(false), _compressionStream(nullptr) { - _headerBuffer.Resize(SizeOfClientHeader[0][0]); + _headerBuffer.Resize(SizeOfClientHeader[0]); } WorldSocket::~WorldSocket() @@ -116,7 +116,9 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result) } } - AsyncRead(); + _packetBuffer.Resize(2 + ClientConnectionInitialize.length() + 1); + + AsyncReadWithCallback(&WorldSocket::InitializeHandler); MessageBuffer initializer; ServerPktHeader header; @@ -128,6 +130,65 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result) QueuePacket(std::move(initializer)); } +void WorldSocket::InitializeHandler(boost::system::error_code error, std::size_t transferedBytes) +{ + if (error) + { + CloseSocket(); + return; + } + + GetReadBuffer().WriteCompleted(transferedBytes); + + MessageBuffer& packet = GetReadBuffer(); + if (packet.GetActiveSize() > 0) + { + if (_packetBuffer.GetRemainingSpace() > 0) + { + // need to receive the header + std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace()); + _packetBuffer.Write(packet.GetReadPointer(), readHeaderSize); + packet.ReadCompleted(readHeaderSize); + + if (_packetBuffer.GetRemainingSpace() > 0) + { + // Couldn't receive the whole header this time. + ASSERT(packet.GetActiveSize() == 0); + AsyncReadWithCallback(&WorldSocket::InitializeHandler); + return; + } + + std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer() + 2), std::min(_packetBuffer.GetActiveSize() - 2, ClientConnectionInitialize.length())); + if (initializer != ClientConnectionInitialize) + { + CloseSocket(); + return; + } + + _compressionStream = new z_stream(); + _compressionStream->zalloc = (alloc_func)NULL; + _compressionStream->zfree = (free_func)NULL; + _compressionStream->opaque = (voidpf)NULL; + _compressionStream->avail_in = 0; + _compressionStream->next_in = NULL; + int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY); + if (z_res != Z_OK) + { + CloseSocket(); + TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res)); + return; + } + + _packetBuffer.Reset(); + HandleSendAuthSession(); + AsyncRead(); + return; + } + } + + AsyncReadWithCallback(&WorldSocket::InitializeHandler); +} + bool WorldSocket::Update() { EncryptablePacket* queued; @@ -266,9 +327,7 @@ void WorldSocket::ExtractOpcodeAndSize(ClientPktHeader const* header, uint32& op else { opcode = header->Setup.Command; - size = header->Setup.Size; - if (_initialized) - size -= 4; + size = header->Setup.Size - 4; } } @@ -281,7 +340,7 @@ void WorldSocket::SetWorldSession(WorldSession* session) bool WorldSocket::ReadHeaderHandler() { - ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()]); + ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_authCrypt.IsInitialized()]); _authCrypt.DecryptRecv(_headerBuffer.GetReadPointer(), _headerBuffer.GetActiveSize()); @@ -291,7 +350,7 @@ bool WorldSocket::ReadHeaderHandler() ExtractOpcodeAndSize(header, opcode, size); - if (!ClientPktHeader::IsValidSize(size) || (_initialized && !ClientPktHeader::IsValidOpcode(opcode))) + if (!ClientPktHeader::IsValidSize(size) || !ClientPktHeader::IsValidOpcode(opcode)) { TC_LOG_ERROR("network", "WorldSocket::ReadHeaderHandler(): client %s sent malformed packet (size: %u, cmd: %u)", GetRemoteIpAddress().to_string().c_str(), size, opcode); @@ -304,133 +363,106 @@ bool WorldSocket::ReadHeaderHandler() WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler() { - if (_initialized) - { - ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer()); - uint32 cmd; - uint32 size; + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer()); + uint32 cmd; + uint32 size; - ExtractOpcodeAndSize(header, cmd, size); + ExtractOpcodeAndSize(header, cmd, size); - OpcodeClient opcode = static_cast<OpcodeClient>(cmd); + OpcodeClient opcode = static_cast<OpcodeClient>(cmd); - WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType()); + WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType()); - if (sPacketLog->CanLogPacket()) - sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType()); + if (sPacketLog->CanLogPacket()) + sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType()); - std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock); + std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock); - switch (opcode) + switch (opcode) + { + case CMSG_PING: + LogOpcodeText(opcode, sessionGuard); + return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error; + case CMSG_AUTH_SESSION: { - case CMSG_PING: - LogOpcodeText(opcode, sessionGuard); - return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error; - case CMSG_AUTH_SESSION: + LogOpcodeText(opcode, sessionGuard); + if (_authed) { - LogOpcodeText(opcode, sessionGuard); - if (_authed) - { - // locking just to safely log offending user is probably overkill but we are disconnecting him anyway - if (sessionGuard.try_lock()) - TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); - return ReadDataHandlerResult::Error; - } - - std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet)); - authSession->Read(); - HandleAuthSession(authSession); - return ReadDataHandlerResult::WaitingForQuery; + // locking just to safely log offending user is probably overkill but we are disconnecting him anyway + if (sessionGuard.try_lock()) + TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); + return ReadDataHandlerResult::Error; } - case CMSG_AUTH_CONTINUED_SESSION: - { - LogOpcodeText(opcode, sessionGuard); - if (_authed) - { - // locking just to safely log offending user is probably overkill but we are disconnecting him anyway - if (sessionGuard.try_lock()) - TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); - return ReadDataHandlerResult::Error; - } - std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet)); - authSession->Read(); - HandleAuthContinuedSession(authSession); - return ReadDataHandlerResult::WaitingForQuery; - } - case CMSG_KEEP_ALIVE: - LogOpcodeText(opcode, sessionGuard); - break; - case CMSG_LOG_DISCONNECT: - LogOpcodeText(opcode, sessionGuard); - packet.rfinish(); // contains uint32 disconnectReason; - break; - case CMSG_ENABLE_NAGLE: - LogOpcodeText(opcode, sessionGuard); - SetNoDelay(false); - break; - case CMSG_CONNECT_TO_FAILED: + std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet)); + authSession->Read(); + HandleAuthSession(authSession); + return ReadDataHandlerResult::WaitingForQuery; + } + case CMSG_AUTH_CONTINUED_SESSION: + { + LogOpcodeText(opcode, sessionGuard); + if (_authed) { - sessionGuard.lock(); - - LogOpcodeText(opcode, sessionGuard); - WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet)); - connectToFailed.Read(); - HandleConnectToFailed(connectToFailed); - break; + // locking just to safely log offending user is probably overkill but we are disconnecting him anyway + if (sessionGuard.try_lock()) + TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); + return ReadDataHandlerResult::Error; } - default: - { - sessionGuard.lock(); - LogOpcodeText(opcode, sessionGuard); + std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet)); + authSession->Read(); + HandleAuthContinuedSession(authSession); + return ReadDataHandlerResult::WaitingForQuery; + } + case CMSG_KEEP_ALIVE: + LogOpcodeText(opcode, sessionGuard); + break; + case CMSG_LOG_DISCONNECT: + LogOpcodeText(opcode, sessionGuard); + packet.rfinish(); // contains uint32 disconnectReason; + break; + case CMSG_ENABLE_NAGLE: + LogOpcodeText(opcode, sessionGuard); + SetNoDelay(false); + break; + case CMSG_CONNECT_TO_FAILED: + { + sessionGuard.lock(); - if (!_worldSession) - { - TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); - return ReadDataHandlerResult::Error; - } + LogOpcodeText(opcode, sessionGuard); + WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet)); + connectToFailed.Read(); + HandleConnectToFailed(connectToFailed); + break; + } + default: + { + sessionGuard.lock(); - OpcodeHandler const* handler = opcodeTable[opcode]; - if (!handler) - { - TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str()); - break; - } + LogOpcodeText(opcode, sessionGuard); - // Our Idle timer will reset on any non PING opcodes. - // Catches people idling on the login screen and any lingering ingame connections. - _worldSession->ResetTimeOutTime(); + if (!_worldSession) + { + TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); + return ReadDataHandlerResult::Error; + } - // Copy the packet to the heap before enqueuing - _worldSession->QueuePacket(new WorldPacket(std::move(packet))); + OpcodeHandler const* handler = opcodeTable[opcode]; + if (!handler) + { + TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str()); break; } - } - } - else - { - std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer()), std::min(_packetBuffer.GetActiveSize(), ClientConnectionInitialize.length())); - if (initializer != ClientConnectionInitialize) - return ReadDataHandlerResult::Error; - - _compressionStream = new z_stream(); - _compressionStream->zalloc = (alloc_func)NULL; - _compressionStream->zfree = (free_func)NULL; - _compressionStream->opaque = (voidpf)NULL; - _compressionStream->avail_in = 0; - _compressionStream->next_in = NULL; - int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY); - if (z_res != Z_OK) - { - TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res)); - return ReadDataHandlerResult::Error; - } - _initialized = true; - _headerBuffer.Resize(SizeOfClientHeader[1][0]); - _packetBuffer.Reset(); - HandleSendAuthSession(); + // Our Idle timer will reset on any non PING opcodes. + // Catches people idling on the login screen and any lingering ingame connections. + _worldSession->ResetTimeOutTime(); + + // Copy the packet to the heap before enqueuing + _worldSession->QueuePacket(new WorldPacket(std::move(packet))); + break; + } } return ReadDataHandlerResult::Ok; @@ -610,7 +642,7 @@ struct AccountInfo void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession) { // Client switches packet headers after sending CMSG_AUTH_SESSION - _headerBuffer.Resize(SizeOfClientHeader[1][1]); + _headerBuffer.Resize(SizeOfClientHeader[1]); // Get the account information from the auth database PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_BY_NAME); @@ -811,7 +843,7 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth: } // Client switches packet headers after sending CMSG_AUTH_CONTINUED_SESSION - _headerBuffer.Resize(SizeOfClientHeader[1][1]); + _headerBuffer.Resize(SizeOfClientHeader[1]); uint32 accountId = uint32(key.Fields.AccountId); PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION); diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 205494ca4ea..0191d4d04d5 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -107,6 +107,7 @@ protected: ReadDataHandlerResult ReadDataHandler(); private: void CheckIpCallback(PreparedQueryResult result); + void InitializeHandler(boost::system::error_code error, std::size_t transferedBytes); /// writes network.opcode log /// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock @@ -148,8 +149,6 @@ private: z_stream_s* _compressionStream; - bool _initialized; - PreparedQueryResultFuture _queryFuture; std::function<void(PreparedQueryResult&&)> _queryCallback; std::string _ipCountry; diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 07f427652aa..80655441ec2 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -90,6 +90,17 @@ public: std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } + void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t)) + { + if (!IsOpen()) + return; + + _readBuffer.Normalize(); + _readBuffer.EnsureFreeSpace(); + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + } + void QueuePacket(MessageBuffer&& buffer) { _writeQueue.push(std::move(buffer)); |