aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2016-03-10 23:26:26 +0100
committerShauren <shauren.trinity@gmail.com>2016-03-10 23:33:56 +0100
commitf123c396591ffb50fa7e02365235740df618f579 (patch)
tree20eeb1e0c76814267e204b2314cf42db21d5aab4 /src
parent52bb6485417494e58361a75fcf73cf5f0a1d09c5 (diff)
Core/Networking: Added new AsyncRead method to Socket class allowing to pass a custom completion handler and refactor world socket initialization string handling
Diffstat (limited to 'src')
-rw-r--r--src/server/game/Server/WorldSocket.cpp272
-rw-r--r--src/server/game/Server/WorldSocket.h3
-rw-r--r--src/server/shared/Networking/Socket.h11
3 files changed, 164 insertions, 122 deletions
diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp
index acec25a7363..843b94bf628 100644
--- a/src/server/game/Server/WorldSocket.cpp
+++ b/src/server/game/Server/WorldSocket.cpp
@@ -57,18 +57,18 @@ std::string const WorldSocket::ServerConnectionInitialize("WORLD OF WARCRAFT CON
std::string const WorldSocket::ClientConnectionInitialize("WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER");
uint32 const WorldSocket::MinSizeForCompression = 0x400;
-uint32 const SizeOfClientHeader[2][2] =
+uint32 const SizeOfClientHeader[2] =
{
- { 2, 0 },
- { 6, 4 }
+ 6, 4
};
uint32 const SizeOfServerHeader[2] = { sizeof(uint16) + sizeof(uint32), sizeof(uint32) };
+
WorldSocket::WorldSocket(tcp::socket&& socket) : Socket(std::move(socket)),
_type(CONNECTION_TYPE_REALM), _authSeed(rand32()), _OverSpeedPings(0),
- _worldSession(nullptr), _authed(false), _compressionStream(nullptr), _initialized(false)
+ _worldSession(nullptr), _authed(false), _compressionStream(nullptr)
{
- _headerBuffer.Resize(SizeOfClientHeader[0][0]);
+ _headerBuffer.Resize(SizeOfClientHeader[0]);
}
WorldSocket::~WorldSocket()
@@ -116,7 +116,9 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
}
}
- AsyncRead();
+ _packetBuffer.Resize(2 + ClientConnectionInitialize.length() + 1);
+
+ AsyncReadWithCallback(&WorldSocket::InitializeHandler);
MessageBuffer initializer;
ServerPktHeader header;
@@ -128,6 +130,65 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
QueuePacket(std::move(initializer));
}
+void WorldSocket::InitializeHandler(boost::system::error_code error, std::size_t transferedBytes)
+{
+ if (error)
+ {
+ CloseSocket();
+ return;
+ }
+
+ GetReadBuffer().WriteCompleted(transferedBytes);
+
+ MessageBuffer& packet = GetReadBuffer();
+ if (packet.GetActiveSize() > 0)
+ {
+ if (_packetBuffer.GetRemainingSpace() > 0)
+ {
+ // need to receive the header
+ std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace());
+ _packetBuffer.Write(packet.GetReadPointer(), readHeaderSize);
+ packet.ReadCompleted(readHeaderSize);
+
+ if (_packetBuffer.GetRemainingSpace() > 0)
+ {
+ // Couldn't receive the whole header this time.
+ ASSERT(packet.GetActiveSize() == 0);
+ AsyncReadWithCallback(&WorldSocket::InitializeHandler);
+ return;
+ }
+
+ std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer() + 2), std::min(_packetBuffer.GetActiveSize() - 2, ClientConnectionInitialize.length()));
+ if (initializer != ClientConnectionInitialize)
+ {
+ CloseSocket();
+ return;
+ }
+
+ _compressionStream = new z_stream();
+ _compressionStream->zalloc = (alloc_func)NULL;
+ _compressionStream->zfree = (free_func)NULL;
+ _compressionStream->opaque = (voidpf)NULL;
+ _compressionStream->avail_in = 0;
+ _compressionStream->next_in = NULL;
+ int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
+ if (z_res != Z_OK)
+ {
+ CloseSocket();
+ TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res));
+ return;
+ }
+
+ _packetBuffer.Reset();
+ HandleSendAuthSession();
+ AsyncRead();
+ return;
+ }
+ }
+
+ AsyncReadWithCallback(&WorldSocket::InitializeHandler);
+}
+
bool WorldSocket::Update()
{
EncryptablePacket* queued;
@@ -266,9 +327,7 @@ void WorldSocket::ExtractOpcodeAndSize(ClientPktHeader const* header, uint32& op
else
{
opcode = header->Setup.Command;
- size = header->Setup.Size;
- if (_initialized)
- size -= 4;
+ size = header->Setup.Size - 4;
}
}
@@ -281,7 +340,7 @@ void WorldSocket::SetWorldSession(WorldSession* session)
bool WorldSocket::ReadHeaderHandler()
{
- ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()]);
+ ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_authCrypt.IsInitialized()]);
_authCrypt.DecryptRecv(_headerBuffer.GetReadPointer(), _headerBuffer.GetActiveSize());
@@ -291,7 +350,7 @@ bool WorldSocket::ReadHeaderHandler()
ExtractOpcodeAndSize(header, opcode, size);
- if (!ClientPktHeader::IsValidSize(size) || (_initialized && !ClientPktHeader::IsValidOpcode(opcode)))
+ if (!ClientPktHeader::IsValidSize(size) || !ClientPktHeader::IsValidOpcode(opcode))
{
TC_LOG_ERROR("network", "WorldSocket::ReadHeaderHandler(): client %s sent malformed packet (size: %u, cmd: %u)",
GetRemoteIpAddress().to_string().c_str(), size, opcode);
@@ -304,133 +363,106 @@ bool WorldSocket::ReadHeaderHandler()
WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler()
{
- if (_initialized)
- {
- ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer());
- uint32 cmd;
- uint32 size;
+ ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer());
+ uint32 cmd;
+ uint32 size;
- ExtractOpcodeAndSize(header, cmd, size);
+ ExtractOpcodeAndSize(header, cmd, size);
- OpcodeClient opcode = static_cast<OpcodeClient>(cmd);
+ OpcodeClient opcode = static_cast<OpcodeClient>(cmd);
- WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType());
+ WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType());
- if (sPacketLog->CanLogPacket())
- sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());
+ if (sPacketLog->CanLogPacket())
+ sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());
- std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock);
+ std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock);
- switch (opcode)
+ switch (opcode)
+ {
+ case CMSG_PING:
+ LogOpcodeText(opcode, sessionGuard);
+ return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error;
+ case CMSG_AUTH_SESSION:
{
- case CMSG_PING:
- LogOpcodeText(opcode, sessionGuard);
- return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error;
- case CMSG_AUTH_SESSION:
+ LogOpcodeText(opcode, sessionGuard);
+ if (_authed)
{
- LogOpcodeText(opcode, sessionGuard);
- if (_authed)
- {
- // locking just to safely log offending user is probably overkill but we are disconnecting him anyway
- if (sessionGuard.try_lock())
- TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
- return ReadDataHandlerResult::Error;
- }
-
- std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet));
- authSession->Read();
- HandleAuthSession(authSession);
- return ReadDataHandlerResult::WaitingForQuery;
+ // locking just to safely log offending user is probably overkill but we are disconnecting him anyway
+ if (sessionGuard.try_lock())
+ TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
+ return ReadDataHandlerResult::Error;
}
- case CMSG_AUTH_CONTINUED_SESSION:
- {
- LogOpcodeText(opcode, sessionGuard);
- if (_authed)
- {
- // locking just to safely log offending user is probably overkill but we are disconnecting him anyway
- if (sessionGuard.try_lock())
- TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
- return ReadDataHandlerResult::Error;
- }
- std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet));
- authSession->Read();
- HandleAuthContinuedSession(authSession);
- return ReadDataHandlerResult::WaitingForQuery;
- }
- case CMSG_KEEP_ALIVE:
- LogOpcodeText(opcode, sessionGuard);
- break;
- case CMSG_LOG_DISCONNECT:
- LogOpcodeText(opcode, sessionGuard);
- packet.rfinish(); // contains uint32 disconnectReason;
- break;
- case CMSG_ENABLE_NAGLE:
- LogOpcodeText(opcode, sessionGuard);
- SetNoDelay(false);
- break;
- case CMSG_CONNECT_TO_FAILED:
+ std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet));
+ authSession->Read();
+ HandleAuthSession(authSession);
+ return ReadDataHandlerResult::WaitingForQuery;
+ }
+ case CMSG_AUTH_CONTINUED_SESSION:
+ {
+ LogOpcodeText(opcode, sessionGuard);
+ if (_authed)
{
- sessionGuard.lock();
-
- LogOpcodeText(opcode, sessionGuard);
- WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet));
- connectToFailed.Read();
- HandleConnectToFailed(connectToFailed);
- break;
+ // locking just to safely log offending user is probably overkill but we are disconnecting him anyway
+ if (sessionGuard.try_lock())
+ TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
+ return ReadDataHandlerResult::Error;
}
- default:
- {
- sessionGuard.lock();
- LogOpcodeText(opcode, sessionGuard);
+ std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet));
+ authSession->Read();
+ HandleAuthContinuedSession(authSession);
+ return ReadDataHandlerResult::WaitingForQuery;
+ }
+ case CMSG_KEEP_ALIVE:
+ LogOpcodeText(opcode, sessionGuard);
+ break;
+ case CMSG_LOG_DISCONNECT:
+ LogOpcodeText(opcode, sessionGuard);
+ packet.rfinish(); // contains uint32 disconnectReason;
+ break;
+ case CMSG_ENABLE_NAGLE:
+ LogOpcodeText(opcode, sessionGuard);
+ SetNoDelay(false);
+ break;
+ case CMSG_CONNECT_TO_FAILED:
+ {
+ sessionGuard.lock();
- if (!_worldSession)
- {
- TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
- return ReadDataHandlerResult::Error;
- }
+ LogOpcodeText(opcode, sessionGuard);
+ WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet));
+ connectToFailed.Read();
+ HandleConnectToFailed(connectToFailed);
+ break;
+ }
+ default:
+ {
+ sessionGuard.lock();
- OpcodeHandler const* handler = opcodeTable[opcode];
- if (!handler)
- {
- TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str());
- break;
- }
+ LogOpcodeText(opcode, sessionGuard);
- // Our Idle timer will reset on any non PING opcodes.
- // Catches people idling on the login screen and any lingering ingame connections.
- _worldSession->ResetTimeOutTime();
+ if (!_worldSession)
+ {
+ TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
+ return ReadDataHandlerResult::Error;
+ }
- // Copy the packet to the heap before enqueuing
- _worldSession->QueuePacket(new WorldPacket(std::move(packet)));
+ OpcodeHandler const* handler = opcodeTable[opcode];
+ if (!handler)
+ {
+ TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str());
break;
}
- }
- }
- else
- {
- std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer()), std::min(_packetBuffer.GetActiveSize(), ClientConnectionInitialize.length()));
- if (initializer != ClientConnectionInitialize)
- return ReadDataHandlerResult::Error;
-
- _compressionStream = new z_stream();
- _compressionStream->zalloc = (alloc_func)NULL;
- _compressionStream->zfree = (free_func)NULL;
- _compressionStream->opaque = (voidpf)NULL;
- _compressionStream->avail_in = 0;
- _compressionStream->next_in = NULL;
- int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
- if (z_res != Z_OK)
- {
- TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res));
- return ReadDataHandlerResult::Error;
- }
- _initialized = true;
- _headerBuffer.Resize(SizeOfClientHeader[1][0]);
- _packetBuffer.Reset();
- HandleSendAuthSession();
+ // Our Idle timer will reset on any non PING opcodes.
+ // Catches people idling on the login screen and any lingering ingame connections.
+ _worldSession->ResetTimeOutTime();
+
+ // Copy the packet to the heap before enqueuing
+ _worldSession->QueuePacket(new WorldPacket(std::move(packet)));
+ break;
+ }
}
return ReadDataHandlerResult::Ok;
@@ -610,7 +642,7 @@ struct AccountInfo
void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession)
{
// Client switches packet headers after sending CMSG_AUTH_SESSION
- _headerBuffer.Resize(SizeOfClientHeader[1][1]);
+ _headerBuffer.Resize(SizeOfClientHeader[1]);
// Get the account information from the auth database
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_BY_NAME);
@@ -811,7 +843,7 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth:
}
// Client switches packet headers after sending CMSG_AUTH_CONTINUED_SESSION
- _headerBuffer.Resize(SizeOfClientHeader[1][1]);
+ _headerBuffer.Resize(SizeOfClientHeader[1]);
uint32 accountId = uint32(key.Fields.AccountId);
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION);
diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h
index 205494ca4ea..0191d4d04d5 100644
--- a/src/server/game/Server/WorldSocket.h
+++ b/src/server/game/Server/WorldSocket.h
@@ -107,6 +107,7 @@ protected:
ReadDataHandlerResult ReadDataHandler();
private:
void CheckIpCallback(PreparedQueryResult result);
+ void InitializeHandler(boost::system::error_code error, std::size_t transferedBytes);
/// writes network.opcode log
/// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock
@@ -148,8 +149,6 @@ private:
z_stream_s* _compressionStream;
- bool _initialized;
-
PreparedQueryResultFuture _queryFuture;
std::function<void(PreparedQueryResult&&)> _queryCallback;
std::string _ipCountry;
diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h
index 07f427652aa..80655441ec2 100644
--- a/src/server/shared/Networking/Socket.h
+++ b/src/server/shared/Networking/Socket.h
@@ -90,6 +90,17 @@ public:
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
+ void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
+ {
+ if (!IsOpen())
+ return;
+
+ _readBuffer.Normalize();
+ _readBuffer.EnsureFreeSpace();
+ _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
+ std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
void QueuePacket(MessageBuffer&& buffer)
{
_writeQueue.push(std::move(buffer));