aboutsummaryrefslogtreecommitdiff
path: root/src/server/authserver/Server/AuthSession.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/server/authserver/Server/AuthSession.cpp')
-rw-r--r--src/server/authserver/Server/AuthSession.cpp63
1 files changed, 42 insertions, 21 deletions
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp
index 78aa977ab92..5c0628334fd 100644
--- a/src/server/authserver/Server/AuthSession.cpp
+++ b/src/server/authserver/Server/AuthSession.cpp
@@ -17,13 +17,16 @@
*/
#include "AuthSession.h"
+#include "AES.h"
#include "AuthCodes.h"
#include "Config.h"
+#include "CryptoGenerics.h"
+#include "DatabaseEnv.h"
#include "Errors.h"
#include "IPLocation.h"
#include "Log.h"
-#include "DatabaseEnv.h"
#include "RealmList.h"
+#include "SecretMgr.h"
#include "SHA1.h"
#include "TOTP.h"
#include "Util.h"
@@ -139,7 +142,7 @@ void AccountInfo::LoadResult(Field* fields)
// 0 1 2 3 4 5 6
//SELECT a.id, a.username, a.locked, a.lock_country, a.last_ip, a.failed_logins, ab.unbandate > UNIX_TIMESTAMP() OR ab.unbandate = ab.bandate,
// 7 8 9 10 11 12
- // ab.unbandate = ab.bandate, aa.gmlevel, a.token_key, a.sha_pass_hash, a.v, a.s
+ // ab.unbandate = ab.bandate, aa.gmlevel, a.totp_secret, a.sha_pass_hash, a.v, a.s
//FROM account a LEFT JOIN account_access aa ON a.id = aa.id LEFT JOIN account_banned ab ON ab.id = a.id AND ab.active = 1 WHERE a.username = ?
Id = fields[0].GetUInt32();
@@ -380,6 +383,25 @@ void AuthSession::LogonChallengeCallback(PreparedQueryResult result)
}
}
+ uint8 securityFlags = 0;
+ // Check if a TOTP token is needed
+ if (!fields[9].IsNull())
+ {
+ securityFlags = 4;
+ _totpSecret = fields[9].GetBinary();
+ if (auto const& secret = sSecretMgr->GetSecret(SECRET_TOTP_MASTER_KEY))
+ {
+ bool success = Trinity::Crypto::AEDecrypt<Trinity::Crypto::AES>(*_totpSecret, *secret);
+ if (!success)
+ {
+ pkt << uint8(WOW_FAIL_DB_BUSY);
+ TC_LOG_ERROR("server.authserver", "[AuthChallenge] Account '%s' has invalid ciphertext for TOTP token key stored", _accountInfo.Login.c_str());
+ SendPacket(pkt);
+ return;
+ }
+ }
+ }
+
// Get the password from the account table, upper it, and make the SRP6 calculation
std::string rI = fields[10].GetString();
@@ -421,13 +443,6 @@ void AuthSession::LogonChallengeCallback(PreparedQueryResult result)
pkt.append(N.AsByteArray(32).get(), 32);
pkt.append(s.AsByteArray(int32(BufferSizes::SRP_6_S)).get(), size_t(BufferSizes::SRP_6_S)); // 32 bytes
pkt.append(VersionChallenge.data(), VersionChallenge.size());
- uint8 securityFlags = 0;
-
- // Check if token is used
- _tokenKey = fields[9].GetString();
- if (!_tokenKey.empty())
- securityFlags = 4;
-
pkt << uint8(securityFlags); // security flags (0x0...0x04)
if (securityFlags & 0x01) // PIN input
@@ -548,23 +563,29 @@ bool AuthSession::HandleLogonProof()
if (!memcmp(M.AsByteArray(sha.GetLength()).get(), logonProof->M1, 20))
{
// Check auth token
- if ((logonProof->securityFlags & 0x04) || !_tokenKey.empty())
+ bool tokenSuccess = false;
+ bool sentToken = (logonProof->securityFlags & 0x04);
+ if (sentToken && _totpSecret)
{
uint8 size = *(GetReadBuffer().GetReadPointer() + sizeof(sAuthLogonProof_C));
std::string token(reinterpret_cast<char*>(GetReadBuffer().GetReadPointer() + sizeof(sAuthLogonProof_C) + sizeof(size)), size);
GetReadBuffer().ReadCompleted(sizeof(size) + size);
- uint32 validToken = TOTP::GenerateToken(_tokenKey.c_str());
- _tokenKey.clear();
+
uint32 incomingToken = atoi(token.c_str());
- if (validToken != incomingToken)
- {
- ByteBuffer packet;
- packet << uint8(AUTH_LOGON_PROOF);
- packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT);
- packet << uint16(0); // LoginFlags, 1 has account message
- SendPacket(packet);
- return true;
- }
+ tokenSuccess = Trinity::Crypto::TOTP::ValidateToken(*_totpSecret, incomingToken);
+ memset(_totpSecret->data(), 0, _totpSecret->size());
+ }
+ else if (!sentToken && !_totpSecret)
+ tokenSuccess = true;
+
+ if (!tokenSuccess)
+ {
+ ByteBuffer packet;
+ packet << uint8(AUTH_LOGON_PROOF);
+ packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT);
+ packet << uint16(0); // LoginFlags, 1 has account message
+ SendPacket(packet);
+ return true;
}
if (!VerifyVersion(logonProof->A, sizeof(logonProof->A), logonProof->crc_hash, false))