diff options
-rw-r--r-- | src/server/authserver/Server/AuthSession.cpp | 240 | ||||
-rw-r--r-- | src/server/authserver/Server/AuthSession.h | 59 | ||||
-rw-r--r-- | src/server/game/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/server/game/Server/WorldSession.cpp | 2 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.cpp | 178 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.h | 25 | ||||
-rw-r--r-- | src/server/shared/Networking/Socket.h | 109 |
7 files changed, 344 insertions, 270 deletions
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp index 913e624635c..38198854bfc 100644 --- a/src/server/authserver/Server/AuthSession.cpp +++ b/src/server/authserver/Server/AuthSession.cpp @@ -111,98 +111,88 @@ typedef struct AUTH_RECONNECT_PROOF_C #pragma pack(pop) - -typedef struct AuthHandler -{ - eAuthCmd cmd; - uint32 status; - size_t packetSize; - bool (AuthSession::*handler)(); -} AuthHandler; - #define BYTE_SIZE 32 #define REALMLIST_SKIP_PACKETS 5 +#define XFER_ACCEPT_SIZE 1 +#define XFER_RESUME_SIZE 9 +#define XFER_CANCEL_SIZE 1 -const AuthHandler table[] = +std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers() { - { AUTH_LOGON_CHALLENGE, STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::_HandleLogonChallenge }, - { AUTH_LOGON_PROOF, STATUS_CONNECTED, sizeof(AUTH_LOGON_PROOF_C), &AuthSession::_HandleLogonProof }, - { AUTH_RECONNECT_CHALLENGE, STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::_HandleReconnectChallenge }, - { AUTH_RECONNECT_PROOF, STATUS_CONNECTED, sizeof(AUTH_RECONNECT_PROOF_C), &AuthSession::_HandleReconnectProof }, - { REALM_LIST, STATUS_AUTHED, REALMLIST_SKIP_PACKETS, &AuthSession::_HandleRealmList } -}; + std::unordered_map<uint8, AuthHandler> handlers; + + handlers[AUTH_LOGON_CHALLENGE] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::HandleLogonChallenge }; + handlers[AUTH_LOGON_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_PROOF_C), &AuthSession::HandleLogonProof }; + handlers[AUTH_RECONNECT_CHALLENGE] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::HandleReconnectChallenge }; + handlers[AUTH_RECONNECT_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_RECONNECT_PROOF_C), &AuthSession::HandleReconnectProof }; + handlers[REALM_LIST] = { STATUS_AUTHED, REALMLIST_SKIP_PACKETS, &AuthSession::HandleRealmList }; + handlers[XFER_ACCEPT] = { STATUS_AUTHED, XFER_ACCEPT_SIZE, &AuthSession::HandleXferAccept }; + handlers[XFER_RESUME] = { STATUS_AUTHED, XFER_RESUME_SIZE, &AuthSession::HandleXferResume }; + handlers[XFER_CANCEL] = { STATUS_AUTHED, XFER_CANCEL_SIZE, &AuthSession::HandleXferCancel }; + + return handlers; +} -void AuthSession::AsyncReadHeader() -{ - auto self(shared_from_this()); +std::unordered_map<uint8, AuthHandler> const Handlers = AuthSession::InitHandlers(); - _socket.async_read_some(boost::asio::buffer(_readBuffer, 1), [this, self](boost::system::error_code error, size_t transferedBytes) +void AuthSession::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) +{ + if (!error && transferedBytes == 1) { - if (!error && transferedBytes == 1) + uint8 cmd = GetReadBuffer()[0]; + auto itr = Handlers.find(cmd); + if (itr != Handlers.end()) { - for (const AuthHandler& entry : table) + // Handle dynamic size packet + if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE) { - if ((uint8)entry.cmd == _readBuffer[0] && (entry.status == STATUS_CONNECTED || (_isAuthenticated && entry.status == STATUS_AUTHED))) - { - // Handle dynamic size packet - if (_readBuffer[0] == AUTH_LOGON_CHALLENGE || _readBuffer[0] == AUTH_RECONNECT_CHALLENGE) - { - _socket.read_some(boost::asio::buffer(&_readBuffer[1], sizeof(uint8) + sizeof(uint16))); //error + size + ReadData(sizeof(uint8) + sizeof(uint16), sizeof(cmd)); //error + size + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); - AsyncReadData(entry.handler, *reinterpret_cast<uint16*>(&_readBuffer[2]), sizeof(uint8) + sizeof(uint8) + sizeof(uint16)); // cmd + error + size - } - else - { - AsyncReadData(entry.handler, entry.packetSize, sizeof(uint8)); - } - break; - } + AsyncReadData(challenge->size, sizeof(uint8) + sizeof(uint8) + sizeof(uint16)); // cmd + error + size } + else + AsyncReadData(itr->second.packetSize, sizeof(uint8)); } - else - { - CloseSocket(); - } - }); + } + else + CloseSocket(); } -void AuthSession::AsyncReadData(bool (AuthSession::*handler)(), size_t dataSize, size_t bufferOffSet) +void AuthSession::ReadDataHandler(boost::system::error_code error, size_t transferedBytes) { - auto self(shared_from_this()); - - _socket.async_read_some(boost::asio::buffer(&_readBuffer[bufferOffSet], dataSize), [handler, this, self](boost::system::error_code error, size_t transferedBytes) + if (!error && transferedBytes > 0) { - if (!error && transferedBytes > 0) - { - if (!(*this.*handler)()) - { - CloseSocket(); - return; - } - - AsyncReadHeader(); - } - else + if (!(*this.*Handlers.at(GetReadBuffer()[0]).handler)()) { CloseSocket(); + return; } - }); + + AsyncReadHeader(); + } + else + CloseSocket(); } -void AuthSession::AsyncWrite(std::size_t length) +void AuthSession::AsyncWrite(ByteBuffer const& packet) { - boost::asio::async_write(_socket, boost::asio::buffer(_writeBuffer, length), [this](boost::system::error_code error, std::size_t /*length*/) - { - if (error) - { - CloseSocket(); - } - }); + std::vector<uint8> data(packet.size()); + std::memcpy(data.data(), packet.contents(), packet.size()); + + std::lock_guard<std::mutex> guard(_writeLock); + + bool needsWriteStart = _writeQueue.empty(); + + _writeQueue.push(std::move(data)); + + if (needsWriteStart) + AsyncWrite(_writeQueue.front()); } -bool AuthSession::_HandleLogonChallenge() +bool AuthSession::HandleLogonChallenge() { - sAuthLogonChallenge_C *challenge = (sAuthLogonChallenge_C*)&_readBuffer; + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -226,8 +216,8 @@ bool AuthSession::_HandleLogonChallenge() // Verify that this IP is not in the ip_banned table LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); - std::string const& ipAddress = _socket.remote_endpoint().address().to_string(); - unsigned short port = _socket.remote_endpoint().port(); + std::string ipAddress = GetRemoteIpAddress().to_string(); + uint16 port = GetRemotePort(); PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_BANNED); stmt->setString(0, ipAddress); @@ -411,20 +401,17 @@ bool AuthSession::_HandleLogonChallenge() pkt << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); } - std::memcpy(_writeBuffer, (char const*)pkt.contents(), pkt.size()); - - AsyncWrite(pkt.size()); - + AsyncWrite(pkt); return true; } // Logon Proof command handler -bool AuthSession::_HandleLogonProof() +bool AuthSession::HandleLogonProof() { TC_LOG_DEBUG("server.authserver", "Entering _HandleLogonProof"); // Read the packet - sAuthLogonProof_C *logonProof = (sAuthLogonProof_C*)&_readBuffer; + sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetReadBuffer()); // If the client has no valid version if (_expversion == NO_VALID_EXP_FLAG) @@ -512,7 +499,7 @@ bool AuthSession::_HandleLogonProof() // Check if SRP6 results match (password is correct), else send an error if (!memcmp(M.AsByteArray().get(), logonProof->M1, 20)) { - TC_LOG_DEBUG("server.authserver", "'%s:%d' User '%s' successfully authenticated", GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str()); + TC_LOG_DEBUG("server.authserver", "'%s:%d' User '%s' successfully authenticated", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); // Update the sessionkey, last_ip, last login time and reset number of failed logins in the account table for this account // No SQL injection (escaped user name) and IP address as received by socket @@ -520,7 +507,7 @@ bool AuthSession::_HandleLogonProof() PreparedStatement *stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_LOGONPROOF); stmt->setString(0, K_hex); - stmt->setString(1, GetRemoteIpAddress().c_str()); + stmt->setString(1, GetRemoteIpAddress().to_string().c_str()); stmt->setUInt32(2, GetLocaleByName(_localizationName)); stmt->setString(3, _os); stmt->setString(4, _login); @@ -549,12 +536,17 @@ bool AuthSession::_HandleLogonProof() delete[] token; if (validToken != incomingToken) { - char data[] = { AUTH_LOGON_PROOF, WOW_FAIL_UNKNOWN_ACCOUNT, 3, 0 }; - socket().send(data, sizeof(data)); - return false; + ByteBuffer packet; + packet << uint8(AUTH_LOGON_PROOF); + packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); + packet << uint8(3); + packet << uint8(0); + AsyncWrite(packet); + return false; }*/ } + ByteBuffer packet; if (_expversion & POST_BC_EXP_FLAG) // 2.x and 3.x clients { sAuthLogonProof_S proof; @@ -565,8 +557,8 @@ bool AuthSession::_HandleLogonProof() proof.unk2 = 0x00; // SurveyId proof.unk3 = 0x00; - std::memcpy(_writeBuffer, (char *)&proof, sizeof(proof)); - AsyncWrite(sizeof(proof)); + packet.resize(sizeof(proof)); + std::memcpy(packet.contents(), &proof, sizeof(proof)); } else { @@ -576,21 +568,24 @@ bool AuthSession::_HandleLogonProof() proof.error = 0; proof.unk2 = 0x00; - std::memcpy(_writeBuffer, (char *)&proof, sizeof(proof)); - AsyncWrite(sizeof(proof)); + packet.resize(sizeof(proof)); + std::memcpy(packet.contents(), &proof, sizeof(proof)); } + AsyncWrite(packet); _isAuthenticated = true; } else { - char data[4] = { AUTH_LOGON_PROOF, WOW_FAIL_UNKNOWN_ACCOUNT, 3, 0 }; - - std::memcpy(_writeBuffer, data, sizeof(data)); - AsyncWrite(sizeof(data)); + ByteBuffer packet; + packet << uint8(AUTH_LOGON_PROOF); + packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); + packet << uint8(3); + packet << uint8(0); + AsyncWrite(packet); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] account %s tried to login with invalid password!", - GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str()); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); uint32 MaxWrongPassCount = sConfigMgr->GetIntDefault("WrongPass.MaxCount", 0); @@ -599,7 +594,7 @@ bool AuthSession::_HandleLogonProof() { PreparedStatement* logstmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_FALP_IP_LOGGING); logstmt->setString(0, _login); - logstmt->setString(1, GetRemoteIpAddress()); + logstmt->setString(1, GetRemoteIpAddress().to_string()); logstmt->setString(2, "Logged on failed AccountLogin due wrong password"); LoginDatabase.Execute(logstmt); @@ -633,17 +628,17 @@ bool AuthSession::_HandleLogonProof() LoginDatabase.Execute(stmt); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] account %s got banned for '%u' seconds because it failed to authenticate '%u' times", - GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str(), WrongPassBanTime, failed_logins); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str(), WrongPassBanTime, failed_logins); } else { stmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_IP_AUTO_BANNED); - stmt->setString(0, GetRemoteIpAddress()); + stmt->setString(0, GetRemoteIpAddress().to_string()); stmt->setUInt32(1, WrongPassBanTime); LoginDatabase.Execute(stmt); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] IP got banned for '%u' seconds because account %s failed to authenticate '%u' times", - GetRemoteIpAddress().c_str(), GetRemotePort(), WrongPassBanTime, _login.c_str(), failed_logins); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), WrongPassBanTime, _login.c_str(), failed_logins); } } } @@ -653,10 +648,10 @@ bool AuthSession::_HandleLogonProof() return true; } -bool AuthSession::_HandleReconnectChallenge() +bool AuthSession::HandleReconnectChallenge() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectChallenge"); - sAuthLogonChallenge_C *challenge = (sAuthLogonChallenge_C*)&_readBuffer; + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -671,7 +666,7 @@ bool AuthSession::_HandleReconnectChallenge() if (!result) { TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login and we cannot find his session key in the database.", - GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str()); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); return false; } @@ -700,15 +695,14 @@ bool AuthSession::_HandleReconnectChallenge() pkt.append(_reconnectProof.AsByteArray(16).get(), 16); // 16 bytes random pkt << uint64(0x00) << uint64(0x00); // 16 bytes zeros - std::memcpy(_writeBuffer, (char const*)pkt.contents(), pkt.size()); - AsyncWrite(pkt.size()); + AsyncWrite(pkt); return true; } -bool AuthSession::_HandleReconnectProof() +bool AuthSession::HandleReconnectProof() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectProof"); - sAuthReconnectProof_C *reconnectProof = (sAuthReconnectProof_C*)&_readBuffer; + sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetReadBuffer()); if (_login.empty() || !_reconnectProof.GetNumBytes() || !K.GetNumBytes()) return false; @@ -729,14 +723,13 @@ bool AuthSession::_HandleReconnectProof() pkt << uint8(AUTH_RECONNECT_PROOF); pkt << uint8(0x00); pkt << uint16(0x00); // 2 bytes zeros - std::memcpy(_writeBuffer, (char const*)pkt.contents(), pkt.size()); - AsyncWrite(pkt.size()); + AsyncWrite(pkt); _isAuthenticated = true; return true; } else { - TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login, but session is invalid.", GetRemoteIpAddress().c_str(), + TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login, but session is invalid.", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); return false; } @@ -777,7 +770,7 @@ tcp::endpoint const GetAddressForClient(Realm const& realm, ip::address const& c return endpoint; } -bool AuthSession::_HandleRealmList() +bool AuthSession::HandleRealmList() { TC_LOG_DEBUG("server.authserver", "Entering _HandleRealmList"); @@ -788,7 +781,7 @@ bool AuthSession::_HandleRealmList() PreparedQueryResult result = LoginDatabase.Query(stmt); if (!result) { - TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login but we cannot find him in the database.", GetRemoteIpAddress().c_str(), + TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login but we cannot find him in the database.", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); return false; } @@ -846,7 +839,7 @@ bool AuthSession::_HandleRealmList() pkt << lock; // if 1, then realm locked pkt << uint8(flag); // RealmFlags pkt << name; - pkt << boost::lexical_cast<std::string>(GetAddressForClient(realm, _socket.remote_endpoint().address())); + pkt << boost::lexical_cast<std::string>(GetAddressForClient(realm, GetRemoteIpAddress())); pkt << realm.populationLevel; pkt << AmountOfCharacters; pkt << realm.timezone; // realm category @@ -890,10 +883,32 @@ bool AuthSession::_HandleRealmList() hdr << uint16(pkt.size() + RealmListSizeBuffer.size()); hdr.append(RealmListSizeBuffer); // append RealmList's size buffer hdr.append(pkt); // append realms in the realmlist + AsyncWrite(hdr); + return true; +} - std::memcpy(_writeBuffer, (char const*)hdr.contents(), hdr.size()); - AsyncWrite(hdr.size()); +// Resume patch transfer +bool AuthSession::HandleXferResume() +{ + TC_LOG_DEBUG("server.authserver", "Entering _HandleXferResume"); + //uint8 + //uint64 + return true; +} +// Cancel patch transfer +bool AuthSession::HandleXferCancel() +{ + TC_LOG_DEBUG("server.authserver", "Entering _HandleXferCancel"); + //uint8 + return false; +} + +// Accept patch transfer +bool AuthSession::HandleXferAccept() +{ + TC_LOG_DEBUG("server.authserver", "Entering _HandleXferAccept"); + //uint8 return true; } @@ -935,12 +950,3 @@ void AuthSession::SetVSFields(const std::string& rI) OPENSSL_free(v_hex); OPENSSL_free(s_hex); } - -void AuthSession::CloseSocket() -{ - boost::system::error_code socketError; - _socket.close(socketError); - if (socketError) - TC_LOG_DEBUG("server.authserver", "Account '%s' errored when closing socket: %i (%s)", - _login.c_str(), socketError.value(), socketError.message().c_str()); -} diff --git a/src/server/authserver/Server/AuthSession.h b/src/server/authserver/Server/AuthSession.h index 6dc9c404857..eedffb86ff8 100644 --- a/src/server/authserver/Server/AuthSession.h +++ b/src/server/authserver/Server/AuthSession.h @@ -19,46 +19,52 @@ #ifndef __AUTHSESSION_H__ #define __AUTHSESSION_H__ -#include <memory> -#include <boost/asio/ip/tcp.hpp> #include "Common.h" +#include "Socket.h" #include "BigNumber.h" +#include <memory> +#include <boost/asio/ip/tcp.hpp> using boost::asio::ip::tcp; -const size_t bufferSize = 4096; +struct AuthHandler; +class ByteBuffer; -#define BUFFER_SIZE 4096 - -class AuthSession : public std::enable_shared_from_this < AuthSession > +class AuthSession : public Socket<AuthSession> { + public: - AuthSession(tcp::socket&& socket) : _socket(std::move(socket)) + static std::unordered_map<uint8, AuthHandler> InitHandlers(); + + AuthSession(tcp::socket&& socket) : Socket(std::move(socket), 1) { N.SetHexStr("894B645E89E1535BBDAD5B8B290650530801B18EBFBF5E8FAB3C82872A3E9BB7"); g.SetDword(7); } - void Start() + void Start() override { AsyncReadHeader(); } - bool _HandleLogonChallenge(); - bool _HandleLogonProof(); - bool _HandleReconnectChallenge(); - bool _HandleReconnectProof(); - bool _HandleRealmList(); + using Socket<AuthSession>::AsyncWrite; + void AsyncWrite(ByteBuffer const& packet); - const std::string GetRemoteIpAddress() const { return _socket.remote_endpoint().address().to_string(); }; - unsigned short GetRemotePort() const { return _socket.remote_endpoint().port(); } +protected: + void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) override; + void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) override; private: - void AsyncReadHeader(); - void AsyncReadData(bool (AuthSession::*handler)(), size_t dataSize, size_t bufferOffset); - void AsyncWrite(size_t length); + bool HandleLogonChallenge(); + bool HandleLogonProof(); + bool HandleReconnectChallenge(); + bool HandleReconnectProof(); + bool HandleRealmList(); - void CloseSocket(); + //data transfer handle for patch + bool HandleXferResume(); + bool HandleXferCancel(); + bool HandleXferAccept(); void SetVSFields(const std::string& rI); @@ -67,10 +73,6 @@ private: BigNumber K; BigNumber _reconnectProof; - tcp::socket _socket; - char _readBuffer[BUFFER_SIZE]; - char _writeBuffer[BUFFER_SIZE]; - bool _isAuthenticated; std::string _tokenKey; std::string _login; @@ -82,4 +84,15 @@ private: AccountTypes _accountSecurityLevel; }; +#pragma pack(push, 1) + +struct AuthHandler +{ + uint32 status; + size_t packetSize; + bool (AuthSession::*handler)(); +}; + +#pragma pack(pop) + #endif diff --git a/src/server/game/CMakeLists.txt b/src/server/game/CMakeLists.txt index 84d6f87d75c..532900c0438 100644 --- a/src/server/game/CMakeLists.txt +++ b/src/server/game/CMakeLists.txt @@ -123,6 +123,7 @@ include_directories( ${CMAKE_SOURCE_DIR}/src/server/shared/Dynamic/LinkedReference ${CMAKE_SOURCE_DIR}/src/server/shared/Dynamic ${CMAKE_SOURCE_DIR}/src/server/shared/Logging + ${CMAKE_SOURCE_DIR}/src/server/shared/Networking ${CMAKE_SOURCE_DIR}/src/server/shared/Packets ${CMAKE_SOURCE_DIR}/src/server/shared/Threading ${CMAKE_SOURCE_DIR}/src/server/shared/Utilities diff --git a/src/server/game/Server/WorldSession.cpp b/src/server/game/Server/WorldSession.cpp index f30991b385e..eb83a5f55be 100644 --- a/src/server/game/Server/WorldSession.cpp +++ b/src/server/game/Server/WorldSession.cpp @@ -130,7 +130,7 @@ WorldSession::WorldSession(uint32 id, std::shared_ptr<WorldSocket> sock, Account if (sock) { - m_Address = sock->GetRemoteIpAddress(); + m_Address = sock->GetRemoteIpAddress().to_string(); ResetTimeOutTime(); LoginDatabase.PExecute("UPDATE account SET online = 1 WHERE id = %u;", GetAccountId()); // One-time query } diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index 575c145687c..87340593c79 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -31,7 +31,7 @@ using boost::asio::ip::tcp; using boost::asio::streambuf; WorldSocket::WorldSocket(tcp::socket&& socket) - : _socket(std::move(socket)), _authSeed(rand32()), _OverSpeedPings(0), _worldSession(nullptr) + : Socket(std::move(socket), sizeof(ClientPktHeader)), _authSeed(rand32()), _OverSpeedPings(0), _worldSession(nullptr) { } @@ -58,105 +58,89 @@ void WorldSocket::HandleSendAuthSession() AsyncWrite(packet); } -void WorldSocket::AsyncReadHeader() +void WorldSocket::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) { - auto self(shared_from_this()); - _socket.async_read_some(boost::asio::buffer(_readBuffer, sizeof(ClientPktHeader)), [this, self](boost::system::error_code error, size_t transferedBytes) + if (!error && transferedBytes == sizeof(ClientPktHeader)) { - if (!error && transferedBytes == sizeof(ClientPktHeader)) - { - ClientPktHeader* header = (ClientPktHeader*)&_readBuffer; - - if (_worldSession) - _authCrypt.DecryptRecv((uint8*)header, sizeof(ClientPktHeader)); + _authCrypt.DecryptRecv(GetReadBuffer(), sizeof(ClientPktHeader)); - EndianConvertReverse(header->size); - EndianConvert(header->cmd); + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetReadBuffer()); + EndianConvertReverse(header->size); + EndianConvert(header->cmd); - AsyncReadData(header->size - sizeof(header->cmd)); - } - else - { - // _socket.is_open() till returns true even after calling close() - CloseSocket(); - } - }); + AsyncReadData(header->size - sizeof(header->cmd), sizeof(ClientPktHeader)); + } + else + CloseSocket(); } -void WorldSocket::AsyncReadData(size_t dataSize) +void WorldSocket::ReadDataHandler(boost::system::error_code error, size_t transferedBytes) { - auto self(shared_from_this()); - _socket.async_read_some(boost::asio::buffer(&_readBuffer[sizeof(ClientPktHeader)], dataSize), [this, dataSize, self](boost::system::error_code error, size_t transferedBytes) + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetReadBuffer()); + + if (!error && transferedBytes == (header->size - sizeof(header->cmd))) { - if (!error && transferedBytes == dataSize) - { - ClientPktHeader* header = (ClientPktHeader*)&_readBuffer; + header->size -= sizeof(header->cmd); - header->size -= sizeof(header->cmd); + uint16 opcode = uint16(header->cmd); - uint16 opcode = (uint16)header->cmd; + std::string opcodeName = GetOpcodeNameForLogging(opcode); - std::string opcodeName = GetOpcodeNameForLogging(opcode); + WorldPacket packet(opcode, header->size); - WorldPacket packet(opcode, header->size); + if (header->size > 0) + { + packet.resize(header->size); - if (header->size > 0) - { - packet.resize(header->size); + std::memcpy(packet.contents(), &(GetReadBuffer()[sizeof(ClientPktHeader)]), header->size); + } - std::memcpy(packet.contents(), &_readBuffer[sizeof(ClientPktHeader)], header->size); - } + if (sPacketLog->CanLogPacket()) + sPacketLog->LogPacket(packet, CLIENT_TO_SERVER); - if (sPacketLog->CanLogPacket()) - sPacketLog->LogPacket(packet, CLIENT_TO_SERVER); + TC_LOG_TRACE("network.opcode", "C->S: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()).c_str(), GetOpcodeNameForLogging(opcode).c_str()); - TC_LOG_TRACE("network.opcode", "C->S: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress()).c_str(), GetOpcodeNameForLogging(opcode).c_str()); + switch (opcode) + { + case CMSG_PING: + HandlePing(packet); + break; + case CMSG_AUTH_SESSION: + if (_worldSession) + { + TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); + break; + } - switch (opcode) + sScriptMgr->OnPacketReceive(shared_from_this(), packet); + HandleAuthSession(packet); + break; + case CMSG_KEEP_ALIVE: + TC_LOG_DEBUG("network", "%s", opcodeName.c_str()); + sScriptMgr->OnPacketReceive(shared_from_this(), packet); + break; + default: { - case CMSG_PING: - HandlePing(packet); - break; - case CMSG_AUTH_SESSION: - if (_worldSession) - { - TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); - break; - } - - sScriptMgr->OnPacketReceive(shared_from_this(), packet); - HandleAuthSession(packet); - break; - case CMSG_KEEP_ALIVE: - TC_LOG_DEBUG("network", "%s", opcodeName.c_str()); - sScriptMgr->OnPacketReceive(shared_from_this(), packet); - break; - default: + if (!_worldSession) { - if (!_worldSession) - { - TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); - break; - } - - // Our Idle timer will reset on any non PING opcodes. - // Catches people idling on the login screen and any lingering ingame connections. - _worldSession->ResetTimeOutTime(); - - // Copy the packet to the heap before enqueuing - _worldSession->QueuePacket(new WorldPacket(packet)); + TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); break; } - } - AsyncReadHeader(); - } - else - { - // _socket.is_open() till returns true even after calling close() - CloseSocket(); + // Our Idle timer will reset on any non PING opcodes. + // Catches people idling on the login screen and any lingering ingame connections. + _worldSession->ResetTimeOutTime(); + + // Copy the packet to the heap before enqueuing + _worldSession->QueuePacket(new WorldPacket(packet)); + break; + } } - }); + + AsyncReadHeader(); + } + else + CloseSocket(); } void WorldSocket::AsyncWrite(WorldPacket const& packet) @@ -164,7 +148,7 @@ void WorldSocket::AsyncWrite(WorldPacket const& packet) if (sPacketLog->CanLogPacket()) sPacketLog->LogPacket(packet, SERVER_TO_CLIENT); - TC_LOG_TRACE("network.opcode", "S->C: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress()).c_str(), GetOpcodeNameForLogging(packet.GetOpcode()).c_str()); + TC_LOG_TRACE("network.opcode", "S->C: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()).c_str(), GetOpcodeNameForLogging(packet.GetOpcode()).c_str()); ServerPktHeader header(packet.size() + 2, packet.GetOpcode()); @@ -185,25 +169,6 @@ void WorldSocket::AsyncWrite(WorldPacket const& packet) AsyncWrite(_writeQueue.front()); } -void WorldSocket::AsyncWrite(std::vector<uint8> const& data) -{ - auto self(shared_from_this()); - boost::asio::async_write(_socket, boost::asio::buffer(data), [this, self](boost::system::error_code error, std::size_t /*length*/) - { - if (!error) - { - std::lock_guard<std::mutex> deleteGuard(_writeLock); - - _writeQueue.pop(); - - if (!_writeQueue.empty()) - AsyncWrite(_writeQueue.front()); - } - else - CloseSocket(); - }); -} - void WorldSocket::HandleAuthSession(WorldPacket& recvPacket) { uint8 digest[20]; @@ -223,7 +188,7 @@ void WorldSocket::HandleAuthSession(WorldPacket& recvPacket) if (sWorld->IsClosed()) { SendAuthResponseError(AUTH_REJECT); - TC_LOG_ERROR("network", "WorldSocket::HandleAuthSession: World closed, denying client (%s).", GetRemoteIpAddress().c_str()); + TC_LOG_ERROR("network", "WorldSocket::HandleAuthSession: World closed, denying client (%s).", GetRemoteIpAddress().to_string().c_str()); return; } @@ -270,7 +235,7 @@ void WorldSocket::HandleAuthSession(WorldPacket& recvPacket) expansion = world_expansion; // For hook purposes, we get Remoteaddress at this point. - std::string address = GetRemoteIpAddress(); + std::string address = GetRemoteIpAddress().to_string(); // As we don't know if attempted login process by ip works, we update last_attempt_ip right away stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_LAST_ATTEMPT_IP); @@ -471,7 +436,7 @@ void WorldSocket::HandlePing(WorldPacket& recvPacket) if (_worldSession && !_worldSession->HasPermission(rbac::RBAC_PERM_SKIP_CHECK_OVERSPEED_PING)) { TC_LOG_ERROR("network", "WorldSocket::HandlePing: %s kicked for over-speed pings (address: %s)", - _worldSession->GetPlayerInfo().c_str(), GetRemoteIpAddress().c_str()); + _worldSession->GetPlayerInfo().c_str(), GetRemoteIpAddress().to_string().c_str()); CloseSocket(); return; @@ -489,8 +454,7 @@ void WorldSocket::HandlePing(WorldPacket& recvPacket) } else { - TC_LOG_ERROR("network", "WorldSocket::HandlePing: peer sent CMSG_PING, but is not authenticated or got recently kicked, address = %s", - GetRemoteIpAddress().c_str()); + TC_LOG_ERROR("network", "WorldSocket::HandlePing: peer sent CMSG_PING, but is not authenticated or got recently kicked, address = %s", GetRemoteIpAddress().to_string().c_str()); CloseSocket(); return; @@ -500,13 +464,3 @@ void WorldSocket::HandlePing(WorldPacket& recvPacket) packet << ping; return AsyncWrite(packet); } - -void WorldSocket::CloseSocket() -{ - boost::system::error_code socketError; - _socket.close(socketError); - if (socketError) - TC_LOG_DEBUG("network", "WorldSocket::CloseSocket: Player '%s' (%s) errored when closing socket: %i (%s)", - _worldSession ? _worldSession->GetPlayerInfo().c_str() : "unknown", GetRemoteIpAddress().c_str(), - socketError.value(), socketError.message().c_str()); -} diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 0f3fc553872..5839c2194c4 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -21,6 +21,7 @@ #include "Common.h" #include "AuthCrypt.h" +#include "Socket.h" #include "Util.h" #include "WorldPacket.h" #include "WorldSession.h" @@ -42,7 +43,7 @@ struct ClientPktHeader #pragma pack(pop) -class WorldSocket : public std::enable_shared_from_this<WorldSocket> +class WorldSocket : public Socket<WorldSocket> { public: WorldSocket(tcp::socket&& socket); @@ -50,15 +51,15 @@ public: WorldSocket(WorldSocket const& right) = delete; WorldSocket& operator=(WorldSocket const& right) = delete; - void Start(); + void Start() override; - std::string GetRemoteIpAddress() const { return _socket.remote_endpoint().address().to_string(); }; - uint16 GetRemotePort() const { return _socket.remote_endpoint().port(); } + void AsyncWrite(WorldPacket const& packet); - void CloseSocket(); - bool IsOpen() const { return _socket.is_open(); } + using Socket<WorldSocket>::AsyncWrite; - void AsyncWrite(WorldPacket const& packet); +protected: + void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) override; + void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) override; private: void HandleSendAuthSession(); @@ -67,16 +68,6 @@ private: void HandlePing(WorldPacket& recvPacket); - void AsyncReadHeader(); - void AsyncReadData(size_t dataSize); - void AsyncWrite(std::vector<uint8> const& data); - - tcp::socket _socket; - - char _readBuffer[4096]; - std::mutex _writeLock; - std::queue<std::vector<uint8> > _writeQueue; - uint32 _authSeed; AuthCrypt _authCrypt; diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h new file mode 100644 index 00000000000..daefa0d4ad5 --- /dev/null +++ b/src/server/shared/Networking/Socket.h @@ -0,0 +1,109 @@ +/* + * Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/> + * Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef __SOCKET_H__ +#define __SOCKET_H__ + +#include "Define.h" +#include "Log.h" +#include <vector> +#include <mutex> +#include <queue> +#include <memory> +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/placeholders.hpp> +#include <boost/bind.hpp> + +using boost::asio::ip::tcp; + +template<class T> +class Socket : public std::enable_shared_from_this<T> +{ +public: + Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _headerSize(headerSize) { } + + virtual void Start() = 0; + + boost::asio::ip::address GetRemoteIpAddress() const { return _socket.remote_endpoint().address(); }; + uint16 GetRemotePort() const { return _socket.remote_endpoint().port(); } + + void AsyncReadHeader() + { + _socket.async_read_some(boost::asio::buffer(_readBuffer, _headerSize), boost::bind(&Socket::ReadHeaderHandlerInternal, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + } + + void AsyncReadData(std::size_t size, std::size_t bufferOffset) + { + _socket.async_read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), boost::bind(&Socket::ReadDataHandlerInternal, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + } + + void ReadData(std::size_t size, std::size_t bufferOffset) + { + _socket.read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size)); + } + + void AsyncWrite(std::vector<uint8> const& data) + { + boost::asio::async_write(_socket, boost::asio::buffer(data), boost::bind(&Socket::WriteHandler, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); + } + + bool IsOpen() const { return _socket.is_open(); } + void CloseSocket() + { + boost::system::error_code socketError; + _socket.close(socketError); + if (socketError) + TC_LOG_DEBUG("network", "Socket::CloseSocket: %s errored when closing socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(), socketError.value(), socketError.message().c_str()); + } + + uint8* GetReadBuffer() { return _readBuffer; } + +protected: + virtual void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) = 0; + virtual void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) = 0; + + std::mutex _writeLock; + std::queue<std::vector<uint8> > _writeQueue; + +private: + void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadHeaderHandler(error, transferedBytes); } + void ReadDataHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadDataHandler(error, transferedBytes); } + + void WriteHandler(boost::system::error_code error, size_t /*transferedBytes*/) + { + if (!error) + { + std::lock_guard<std::mutex> deleteGuard(_writeLock); + + _writeQueue.pop(); + + if (!_writeQueue.empty()) + AsyncWrite(_writeQueue.front()); + } + else + CloseSocket(); + } + + tcp::socket _socket; + + uint8 _readBuffer[4096]; + + std::size_t _headerSize; +}; + +#endif // __SOCKET_H__ |