diff options
Diffstat (limited to 'src/server/authserver/Server/AuthSession.cpp')
-rw-r--r-- | src/server/authserver/Server/AuthSession.cpp | 248 |
1 files changed, 126 insertions, 122 deletions
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp index 446b0a43158..b2fb516804f 100644 --- a/src/server/authserver/Server/AuthSession.cpp +++ b/src/server/authserver/Server/AuthSession.cpp @@ -16,11 +16,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#include <memory> -#include <boost/lexical_cast.hpp> -#include <boost/asio/write.hpp> -#include <AuthSession.h> -#include <Log.h> +#include "AuthSession.h" +#include "Log.h" #include "ByteBuffer.h" #include "AuthCodes.h" #include "Database/DatabaseEnv.h" @@ -28,6 +25,7 @@ #include "openssl/crypto.h" #include "Configuration/Config.h" #include "RealmList.h" +#include <boost/lexical_cast.hpp> using boost::asio::ip::tcp; @@ -111,98 +109,88 @@ typedef struct AUTH_RECONNECT_PROOF_C #pragma pack(pop) - -typedef struct AuthHandler -{ - eAuthCmd cmd; - uint32 status; - size_t packetSize; - bool (AuthSession::*handler)(); -} AuthHandler; - #define BYTE_SIZE 32 #define REALMLIST_SKIP_PACKETS 5 +#define XFER_ACCEPT_SIZE 1 +#define XFER_RESUME_SIZE 9 +#define XFER_CANCEL_SIZE 1 -const AuthHandler table[] = +std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers() { - { AUTH_LOGON_CHALLENGE, STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::_HandleLogonChallenge }, - { AUTH_LOGON_PROOF, STATUS_CONNECTED, sizeof(AUTH_LOGON_PROOF_C), &AuthSession::_HandleLogonProof }, - { AUTH_RECONNECT_CHALLENGE, STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::_HandleReconnectChallenge }, - { AUTH_RECONNECT_PROOF, STATUS_CONNECTED, sizeof(AUTH_RECONNECT_PROOF_C), &AuthSession::_HandleReconnectProof }, - { REALM_LIST, STATUS_AUTHED, REALMLIST_SKIP_PACKETS, &AuthSession::_HandleRealmList } -}; + std::unordered_map<uint8, AuthHandler> handlers; + + handlers[AUTH_LOGON_CHALLENGE] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::HandleLogonChallenge }; + handlers[AUTH_LOGON_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_PROOF_C), &AuthSession::HandleLogonProof }; + handlers[AUTH_RECONNECT_CHALLENGE] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::HandleReconnectChallenge }; + handlers[AUTH_RECONNECT_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_RECONNECT_PROOF_C), &AuthSession::HandleReconnectProof }; + handlers[REALM_LIST] = { STATUS_AUTHED, REALMLIST_SKIP_PACKETS, &AuthSession::HandleRealmList }; + handlers[XFER_ACCEPT] = { STATUS_AUTHED, XFER_ACCEPT_SIZE, &AuthSession::HandleXferAccept }; + handlers[XFER_RESUME] = { STATUS_AUTHED, XFER_RESUME_SIZE, &AuthSession::HandleXferResume }; + handlers[XFER_CANCEL] = { STATUS_AUTHED, XFER_CANCEL_SIZE, &AuthSession::HandleXferCancel }; + + return handlers; +} -void AuthSession::AsyncReadHeader() -{ - auto self(shared_from_this()); +std::unordered_map<uint8, AuthHandler> const Handlers = AuthSession::InitHandlers(); - _socket.async_read_some(boost::asio::buffer(_readBuffer, 1), [this, self](boost::system::error_code error, size_t transferedBytes) +void AuthSession::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) +{ + if (!error && transferedBytes == 1) { - if (!error && transferedBytes == 1) + uint8 cmd = GetReadBuffer()[0]; + auto itr = Handlers.find(cmd); + if (itr != Handlers.end()) { - for (const AuthHandler& entry : table) + // Handle dynamic size packet + if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE) { - if ((uint8)entry.cmd == _readBuffer[0] && (entry.status == STATUS_CONNECTED || (_isAuthenticated && entry.status == STATUS_AUTHED))) - { - // Handle dynamic size packet - if (_readBuffer[0] == AUTH_LOGON_CHALLENGE || _readBuffer[0] == AUTH_RECONNECT_CHALLENGE) - { - _socket.read_some(boost::asio::buffer(&_readBuffer[1], sizeof(uint8) + sizeof(uint16))); //error + size + ReadData(sizeof(uint8) + sizeof(uint16), sizeof(cmd)); //error + size + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); - AsyncReadData(entry.handler, *reinterpret_cast<uint16*>(&_readBuffer[2]), sizeof(uint8) + sizeof(uint8) + sizeof(uint16)); // cmd + error + size - } - else - { - AsyncReadData(entry.handler, entry.packetSize, sizeof(uint8)); - } - break; - } + AsyncReadData(challenge->size, sizeof(uint8) + sizeof(uint8) + sizeof(uint16)); // cmd + error + size } + else + AsyncReadData(itr->second.packetSize, sizeof(uint8)); } - else - { - CloseSocket(); - } - }); + } + else + CloseSocket(); } -void AuthSession::AsyncReadData(bool (AuthSession::*handler)(), size_t dataSize, size_t bufferOffSet) +void AuthSession::ReadDataHandler(boost::system::error_code error, size_t transferedBytes) { - auto self(shared_from_this()); - - _socket.async_read_some(boost::asio::buffer(&_readBuffer[bufferOffSet], dataSize), [handler, this, self](boost::system::error_code error, size_t transferedBytes) + if (!error && transferedBytes > 0) { - if (!error && transferedBytes > 0) - { - if (!(*this.*handler)()) - { - CloseSocket(); - return; - } - - AsyncReadHeader(); - } - else + if (!(*this.*Handlers.at(GetReadBuffer()[0]).handler)()) { CloseSocket(); + return; } - }); + + AsyncReadHeader(); + } + else + CloseSocket(); } -void AuthSession::AsyncWrite(std::size_t length) +void AuthSession::AsyncWrite(ByteBuffer const& packet) { - boost::asio::async_write(_socket, boost::asio::buffer(_writeBuffer, length), [this](boost::system::error_code error, std::size_t /*length*/) - { - if (error) - { - CloseSocket(); - } - }); + std::vector<uint8> data(packet.size()); + std::memcpy(data.data(), packet.contents(), packet.size()); + + std::lock_guard<std::mutex> guard(_writeLock); + + bool needsWriteStart = _writeQueue.empty(); + + _writeQueue.push(std::move(data)); + + if (needsWriteStart) + AsyncWrite(_writeQueue.front()); } -bool AuthSession::_HandleLogonChallenge() +bool AuthSession::HandleLogonChallenge() { - sAuthLogonChallenge_C *challenge = (sAuthLogonChallenge_C*)&_readBuffer; + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -226,8 +214,8 @@ bool AuthSession::_HandleLogonChallenge() // Verify that this IP is not in the ip_banned table LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); - std::string const& ipAddress = _socket.remote_endpoint().address().to_string(); - unsigned short port = _socket.remote_endpoint().port(); + std::string ipAddress = GetRemoteIpAddress().to_string(); + uint16 port = GetRemotePort(); PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_BANNED); stmt->setString(0, ipAddress); @@ -413,20 +401,17 @@ bool AuthSession::_HandleLogonChallenge() pkt << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); } - std::memcpy(_writeBuffer, (char const*)pkt.contents(), pkt.size()); - - AsyncWrite(pkt.size()); - + AsyncWrite(pkt); return true; } // Logon Proof command handler -bool AuthSession::_HandleLogonProof() +bool AuthSession::HandleLogonProof() { TC_LOG_DEBUG("server.authserver", "Entering _HandleLogonProof"); // Read the packet - sAuthLogonProof_C *logonProof = (sAuthLogonProof_C*)&_readBuffer; + sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetReadBuffer()); // If the client has no valid version if (_expversion == NO_VALID_EXP_FLAG) @@ -514,12 +499,12 @@ bool AuthSession::_HandleLogonProof() // Check if SRP6 results match (password is correct), else send an error if (!memcmp(M.AsByteArray().get(), logonProof->M1, 20)) { - TC_LOG_DEBUG("server.authserver", "'%s:%d' User '%s' successfully authenticated", GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str()); + TC_LOG_DEBUG("server.authserver", "'%s:%d' User '%s' successfully authenticated", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); // Update the sessionkey, last_ip, last login time and reset number of failed logins in the account table for this account PreparedStatement *stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_LOGONPROOF); stmt->setString(0, K.AsHexStr()); - stmt->setString(1, GetRemoteIpAddress().c_str()); + stmt->setString(1, GetRemoteIpAddress().to_string().c_str()); stmt->setUInt32(2, GetLocaleByName(_localizationName)); stmt->setString(3, _os); stmt->setString(4, _login); @@ -546,12 +531,17 @@ bool AuthSession::_HandleLogonProof() delete[] token; if (validToken != incomingToken) { - char data[] = { AUTH_LOGON_PROOF, WOW_FAIL_UNKNOWN_ACCOUNT, 3, 0 }; - socket().send(data, sizeof(data)); - return false; + ByteBuffer packet; + packet << uint8(AUTH_LOGON_PROOF); + packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); + packet << uint8(3); + packet << uint8(0); + AsyncWrite(packet); + return false; }*/ } + ByteBuffer packet; if (_expversion & POST_BC_EXP_FLAG) // 2.x and 3.x clients { sAuthLogonProof_S proof; @@ -562,8 +552,8 @@ bool AuthSession::_HandleLogonProof() proof.unk2 = 0x00; // SurveyId proof.unk3 = 0x00; - std::memcpy(_writeBuffer, (char *)&proof, sizeof(proof)); - AsyncWrite(sizeof(proof)); + packet.resize(sizeof(proof)); + std::memcpy(packet.contents(), &proof, sizeof(proof)); } else { @@ -573,21 +563,24 @@ bool AuthSession::_HandleLogonProof() proof.error = 0; proof.unk2 = 0x00; - std::memcpy(_writeBuffer, (char *)&proof, sizeof(proof)); - AsyncWrite(sizeof(proof)); + packet.resize(sizeof(proof)); + std::memcpy(packet.contents(), &proof, sizeof(proof)); } + AsyncWrite(packet); _isAuthenticated = true; } else { - char data[4] = { AUTH_LOGON_PROOF, WOW_FAIL_UNKNOWN_ACCOUNT, 3, 0 }; - - std::memcpy(_writeBuffer, data, sizeof(data)); - AsyncWrite(sizeof(data)); + ByteBuffer packet; + packet << uint8(AUTH_LOGON_PROOF); + packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); + packet << uint8(3); + packet << uint8(0); + AsyncWrite(packet); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] account %s tried to login with invalid password!", - GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str()); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); uint32 MaxWrongPassCount = sConfigMgr->GetIntDefault("WrongPass.MaxCount", 0); @@ -596,7 +589,7 @@ bool AuthSession::_HandleLogonProof() { PreparedStatement* logstmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_FALP_IP_LOGGING); logstmt->setString(0, _login); - logstmt->setString(1, GetRemoteIpAddress()); + logstmt->setString(1, GetRemoteIpAddress().to_string()); logstmt->setString(2, "Logged on failed AccountLogin due wrong password"); LoginDatabase.Execute(logstmt); @@ -630,17 +623,17 @@ bool AuthSession::_HandleLogonProof() LoginDatabase.Execute(stmt); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] account %s got banned for '%u' seconds because it failed to authenticate '%u' times", - GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str(), WrongPassBanTime, failed_logins); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str(), WrongPassBanTime, failed_logins); } else { stmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_IP_AUTO_BANNED); - stmt->setString(0, GetRemoteIpAddress()); + stmt->setString(0, GetRemoteIpAddress().to_string()); stmt->setUInt32(1, WrongPassBanTime); LoginDatabase.Execute(stmt); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] IP got banned for '%u' seconds because account %s failed to authenticate '%u' times", - GetRemoteIpAddress().c_str(), GetRemotePort(), WrongPassBanTime, _login.c_str(), failed_logins); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), WrongPassBanTime, _login.c_str(), failed_logins); } } } @@ -650,10 +643,10 @@ bool AuthSession::_HandleLogonProof() return true; } -bool AuthSession::_HandleReconnectChallenge() +bool AuthSession::HandleReconnectChallenge() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectChallenge"); - sAuthLogonChallenge_C *challenge = (sAuthLogonChallenge_C*)&_readBuffer; + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -668,7 +661,7 @@ bool AuthSession::_HandleReconnectChallenge() if (!result) { TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login and we cannot find his session key in the database.", - GetRemoteIpAddress().c_str(), GetRemotePort(), _login.c_str()); + GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); return false; } @@ -697,15 +690,14 @@ bool AuthSession::_HandleReconnectChallenge() pkt.append(_reconnectProof.AsByteArray(16).get(), 16); // 16 bytes random pkt << uint64(0x00) << uint64(0x00); // 16 bytes zeros - std::memcpy(_writeBuffer, (char const*)pkt.contents(), pkt.size()); - AsyncWrite(pkt.size()); + AsyncWrite(pkt); return true; } -bool AuthSession::_HandleReconnectProof() +bool AuthSession::HandleReconnectProof() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectProof"); - sAuthReconnectProof_C *reconnectProof = (sAuthReconnectProof_C*)&_readBuffer; + sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetReadBuffer()); if (_login.empty() || !_reconnectProof.GetNumBytes() || !K.GetNumBytes()) return false; @@ -726,20 +718,19 @@ bool AuthSession::_HandleReconnectProof() pkt << uint8(AUTH_RECONNECT_PROOF); pkt << uint8(0x00); pkt << uint16(0x00); // 2 bytes zeros - std::memcpy(_writeBuffer, (char const*)pkt.contents(), pkt.size()); - AsyncWrite(pkt.size()); + AsyncWrite(pkt); _isAuthenticated = true; return true; } else { - TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login, but session is invalid.", GetRemoteIpAddress().c_str(), + TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login, but session is invalid.", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); return false; } } -bool AuthSession::_HandleRealmList() +bool AuthSession::HandleRealmList() { TC_LOG_DEBUG("server.authserver", "Entering _HandleRealmList"); @@ -750,7 +741,7 @@ bool AuthSession::_HandleRealmList() PreparedQueryResult result = LoginDatabase.Query(stmt); if (!result) { - TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login but we cannot find him in the database.", GetRemoteIpAddress().c_str(), + TC_LOG_ERROR("server.authserver", "'%s:%d' [ERROR] user %s tried to login but we cannot find him in the database.", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); return false; } @@ -808,7 +799,7 @@ bool AuthSession::_HandleRealmList() pkt << lock; // if 1, then realm locked pkt << uint8(flag); // RealmFlags pkt << name; - pkt << boost::lexical_cast<std::string>(realm.GetAddressForClient(_socket.remote_endpoint().address())); + pkt << boost::lexical_cast<std::string>(realm.GetAddressForClient(GetRemoteIpAddress())); pkt << realm.populationLevel; pkt << AmountOfCharacters; pkt << realm.timezone; // realm category @@ -852,10 +843,32 @@ bool AuthSession::_HandleRealmList() hdr << uint16(pkt.size() + RealmListSizeBuffer.size()); hdr.append(RealmListSizeBuffer); // append RealmList's size buffer hdr.append(pkt); // append realms in the realmlist + AsyncWrite(hdr); + return true; +} + +// Resume patch transfer +bool AuthSession::HandleXferResume() +{ + TC_LOG_DEBUG("server.authserver", "Entering _HandleXferResume"); + //uint8 + //uint64 + return true; +} - std::memcpy(_writeBuffer, (char const*)hdr.contents(), hdr.size()); - AsyncWrite(hdr.size()); +// Cancel patch transfer +bool AuthSession::HandleXferCancel() +{ + TC_LOG_DEBUG("server.authserver", "Entering _HandleXferCancel"); + //uint8 + return false; +} +// Accept patch transfer +bool AuthSession::HandleXferAccept() +{ + TC_LOG_DEBUG("server.authserver", "Entering _HandleXferAccept"); + //uint8 return true; } @@ -889,12 +902,3 @@ void AuthSession::SetVSFields(const std::string& rI) stmt->setString(2, _login); LoginDatabase.Execute(stmt); } - -void AuthSession::CloseSocket() -{ - boost::system::error_code socketError; - _socket.close(socketError); - if (socketError) - TC_LOG_DEBUG("server.authserver", "Account '%s' errored when closing socket: %i (%s)", - _login.c_str(), socketError.value(), socketError.message().c_str()); -} |