aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/server/authserver/Server/AuthSession.cpp78
-rw-r--r--src/server/authserver/Server/AuthSession.h12
2 files changed, 35 insertions, 55 deletions
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp
index 1cdb41ea2c7..8a727193307 100644
--- a/src/server/authserver/Server/AuthSession.cpp
+++ b/src/server/authserver/Server/AuthSession.cpp
@@ -56,18 +56,18 @@ typedef struct AUTH_LOGON_CHALLENGE_C
uint8 cmd;
uint8 error;
uint16 size;
- uint8 gamename[4];
+ uint32 gamename;
uint8 version1;
uint8 version2;
uint8 version3;
uint16 build;
- uint8 platform[4];
- uint8 os[4];
- uint8 country[4];
+ uint32 platform;
+ uint32 os;
+ uint32 country;
uint32 timezone_bias;
uint32 ip;
uint8 I_len;
- uint8 I[1];
+ char I[1];
} sAuthLogonChallenge_C;
static_assert(sizeof(sAuthLogonChallenge_C) == (1 + 1 + 2 + 4 + 1 + 1 + 1 + 2 + 4 + 4 + 4 + 4 + 4 + 1 + 1));
@@ -180,10 +180,10 @@ void AccountInfo::LoadResult(Field* fields)
//FROM account a LEFT JOIN account_access aa ON a.id = aa.AccountID LEFT JOIN account_banned ab ON ab.id = a.id AND ab.active = 1 WHERE a.username = ?
Id = fields[0].GetUInt32();
- Login = fields[1].GetString();
+ Login = fields[1].GetStringView();
IsLockedToIP = fields[2].GetBool();
- LockCountry = fields[3].GetString();
- LastIP = fields[4].GetString();
+ LockCountry = fields[3].GetStringView();
+ LastIP = fields[4].GetStringView();
FailedLogins = fields[5].GetUInt32();
IsBanned = fields[6].GetUInt64() != 0;
IsPermanenetlyBanned = fields[7].GetUInt64() != 0;
@@ -196,7 +196,9 @@ void AccountInfo::LoadResult(Field* fields)
}
AuthSession::AuthSession(tcp::socket&& socket) : Socket(std::move(socket)),
-_status(STATUS_CHALLENGE), _build(0), _timezoneOffset(0min), _expversion(0) { }
+ _status(STATUS_CHALLENGE), _locale(LOCALE_enUS), _os(0), _build(0), _expversion(0), _timezoneOffset(0min)
+{
+}
void AuthSession::Start()
{
@@ -311,28 +313,19 @@ bool AuthSession::HandleLogonChallenge()
if (challenge->size - (sizeof(sAuthLogonChallenge_C) - AUTH_LOGON_CHALLENGE_INITIAL_SIZE - 1) != challenge->I_len)
return false;
- std::string login((char const*)challenge->I, challenge->I_len);
+ std::string_view login(challenge->I, challenge->I_len);
TC_LOG_DEBUG("server.authserver", "[AuthChallenge] '{}'", login);
_build = challenge->build;
_expversion = uint8(AuthHelper::IsPostBCAcceptedClientBuild(_build) ? POST_BC_EXP_FLAG : (AuthHelper::IsPreBCAcceptedClientBuild(_build) ? PRE_BC_EXP_FLAG : NO_VALID_EXP_FLAG));
- std::array<char, 5> os;
- os.fill('\0');
- memcpy(os.data(), challenge->os, sizeof(challenge->os));
- _os = os.data();
-
- // Restore string order as its byte order is reversed
- std::reverse(_os.begin(), _os.end());
-
- _localizationName.resize(4);
- for (int i = 0; i < 4; ++i)
- _localizationName[i] = challenge->country[4 - i - 1];
+ _os = challenge->os;
+ _locale = GetLocaleByName(ClientBuild::ToCharArray(challenge->country).data());
_timezoneOffset = Minutes(challenge->timezone_bias);
// Get the account details from the account table
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_LOGONCHALLENGE);
- stmt->setString(0, login);
+ stmt->setStringView(0, login);
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
.WithPreparedCallback([this](PreparedQueryResult result) { LogonChallengeCallback(std::move(result)); }));
@@ -467,7 +460,7 @@ void AuthSession::LogonChallengeCallback(PreparedQueryResult result)
pkt << uint8(1);
TC_LOG_DEBUG("server.authserver", "'{}:{}' [AuthChallenge] account {} is using '{}' locale ({})",
- ipAddress, port, _accountInfo.Login, _localizationName, GetLocaleByName(_localizationName));
+ ipAddress, port, _accountInfo.Login, localeNames[_locale], uint32(_locale));
_status = STATUS_LOGON_PROOF;
}
@@ -524,7 +517,7 @@ bool AuthSession::HandleLogonProof()
return true;
}
- if (!VerifyVersion(logonProof->A.data(), logonProof->A.size(), logonProof->crc_hash, false))
+ if (!VerifyVersion(logonProof->A, logonProof->crc_hash, false))
{
ByteBuffer packet;
packet << uint8(AUTH_LOGON_PROOF);
@@ -542,8 +535,8 @@ bool AuthSession::HandleLogonProof()
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_LOGONPROOF);
stmt->setBinary(0, _sessionKey);
stmt->setString(1, address);
- stmt->setUInt32(2, GetLocaleByName(_localizationName));
- stmt->setString(3, _os);
+ stmt->setUInt32(2, _locale);
+ stmt->setStringView(3, ClientBuild::ToCharArray(_os).data());
stmt->setInt16(4, _timezoneOffset.count());
stmt->setString(5, _accountInfo.Login);
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
@@ -651,28 +644,19 @@ bool AuthSession::HandleReconnectChallenge()
if (challenge->size - (sizeof(sAuthLogonChallenge_C) - AUTH_LOGON_CHALLENGE_INITIAL_SIZE - 1) != challenge->I_len)
return false;
- std::string login((char const*)challenge->I, challenge->I_len);
+ std::string_view login(challenge->I, challenge->I_len);
TC_LOG_DEBUG("server.authserver", "[ReconnectChallenge] '{}'", login);
_build = challenge->build;
_expversion = uint8(AuthHelper::IsPostBCAcceptedClientBuild(_build) ? POST_BC_EXP_FLAG : (AuthHelper::IsPreBCAcceptedClientBuild(_build) ? PRE_BC_EXP_FLAG : NO_VALID_EXP_FLAG));
- std::array<char, 5> os;
- os.fill('\0');
- memcpy(os.data(), challenge->os, sizeof(challenge->os));
- _os = os.data();
-
- // Restore string order as its byte order is reversed
- std::reverse(_os.begin(), _os.end());
-
- _localizationName.resize(4);
- for (int i = 0; i < 4; ++i)
- _localizationName[i] = challenge->country[4 - i - 1];
+ _os = challenge->os;
+ _locale = GetLocaleByName(ClientBuild::ToCharArray(challenge->country).data());
_timezoneOffset = Minutes(challenge->timezone_bias);
// Get the account details from the account table
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_RECONNECTCHALLENGE);
- stmt->setString(0, login);
+ stmt->setStringView(0, login);
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
.WithPreparedCallback([this](PreparedQueryResult result) { ReconnectChallengeCallback(std::move(result)); }));
@@ -724,7 +708,7 @@ bool AuthSession::HandleReconnectProof()
if (sha.GetDigest() == reconnectProof->R2)
{
- if (!VerifyVersion(reconnectProof->R1, sizeof(reconnectProof->R1), reconnectProof->R3, true))
+ if (!VerifyVersion(reconnectProof->R1, reconnectProof->R3, true))
{
ByteBuffer packet;
packet << uint8(AUTH_RECONNECT_PROOF);
@@ -800,11 +784,7 @@ void AuthSession::RealmListCallback(PreparedQueryResult result)
std::string name = realm.Name;
if (_expversion & PRE_BC_EXP_FLAG && flag & REALM_FLAG_SPECIFYBUILD)
- {
- std::ostringstream ss;
- ss << name << " (" << buildInfo->MajorVersion << '.' << buildInfo->MinorVersion << '.' << buildInfo->BugfixVersion << ')';
- name = ss.str();
- }
+ Trinity::StringFormatTo(std::back_inserter(name), " ({}.{}.{})", buildInfo->MajorVersion, buildInfo->MinorVersion, buildInfo->BugfixVersion);
uint8 lock = (realm.AllowedSecurityLevel > _accountInfo.SecurityLevel) ? 1 : 0;
@@ -886,7 +866,7 @@ bool AuthSession::HandleXferCancel()
return false;
}
-bool AuthSession::VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect)
+bool AuthSession::VerifyVersion(std::span<uint8 const> a, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect)
{
if (!sConfigMgr->GetBoolDefault("StrictVersionCheck", false))
return true;
@@ -899,7 +879,7 @@ bool AuthSession::VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::
if (!buildInfo)
return false;
- auto platformItr = std::ranges::find(buildInfo->ExecutableHashes, ClientBuild::ToFourCC(_os), &ClientBuild::ExecutableHash::Platform);
+ auto platformItr = std::ranges::find(buildInfo->ExecutableHashes, _os, &ClientBuild::ExecutableHash::Platform);
if (platformItr == buildInfo->ExecutableHashes.end())
return true; // not filled serverside
@@ -909,9 +889,9 @@ bool AuthSession::VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::
versionHash = &zeros;
Trinity::Crypto::SHA1 version;
- version.UpdateData(a, aLength);
+ version.UpdateData(a);
version.UpdateData(*versionHash);
version.Finalize();
- return (versionProof == version.GetDigest());
+ return versionProof == version.GetDigest();
}
diff --git a/src/server/authserver/Server/AuthSession.h b/src/server/authserver/Server/AuthSession.h
index d1063f0d493..4b848e30428 100644
--- a/src/server/authserver/Server/AuthSession.h
+++ b/src/server/authserver/Server/AuthSession.h
@@ -27,12 +27,12 @@
#include "Socket.h"
#include "SRP6.h"
#include <boost/asio/ip/tcp.hpp>
+#include <span>
using boost::asio::ip::tcp;
class AuthHandlerTable;
class ByteBuffer;
-enum eAuthCmd : uint8;
enum AuthStatus
{
@@ -91,7 +91,7 @@ private:
void ReconnectChallengeCallback(PreparedQueryResult result);
void RealmListCallback(PreparedQueryResult result);
- bool VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect);
+ bool VerifyVersion(std::span<uint8 const> a, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect);
Optional<Trinity::Crypto::SRP6> _srp6;
SessionKey _sessionKey = {};
@@ -100,12 +100,12 @@ private:
AuthStatus _status;
AccountInfo _accountInfo;
Optional<std::vector<uint8>> _totpSecret;
- std::string _localizationName;
- std::string _os;
- std::string _ipCountry;
+ LocaleConstant _locale;
+ uint32 _os;
+ std::string_view _ipCountry;
uint16 _build;
- Minutes _timezoneOffset;
uint8 _expversion;
+ Minutes _timezoneOffset;
QueryCallbackProcessor _queryProcessor;
};