aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMachiavelli <none@none>2010-09-12 11:06:26 +0200
committerMachiavelli <none@none>2010-09-12 11:06:26 +0200
commit6a4c7988678f99208eefa8c636fe72da4273a75a (patch)
treeb722e74caf73e6afc85adbeb45167559d4253de3 /src
parent0612359f4b1db405cb82708f4efa1a6eb8c3f9d5 (diff)
Core/DBLayer: Replace all ad-hoc queries in AuthSocket with prepared statements
--HG-- branch : trunk
Diffstat (limited to 'src')
-rw-r--r--src/server/authserver/Server/AuthSocket.cpp106
-rw-r--r--src/server/authserver/Server/RealmSocket.cpp2
-rw-r--r--src/server/authserver/Server/RealmSocket.h4
-rw-r--r--src/server/shared/Database/Implementation/LoginDatabase.cpp12
-rw-r--r--src/server/shared/Database/Implementation/LoginDatabase.h13
5 files changed, 87 insertions, 50 deletions
diff --git a/src/server/authserver/Server/AuthSocket.cpp b/src/server/authserver/Server/AuthSocket.cpp
index 1790f3b5bf7..b58082594d4 100644
--- a/src/server/authserver/Server/AuthSocket.cpp
+++ b/src/server/authserver/Server/AuthSocket.cpp
@@ -354,34 +354,31 @@ bool AuthSocket::_HandleLogonChallenge()
///- Normalize account name
//utf8ToUpperOnlyLatin(_login); -- client already send account in expected form
- //Escape the user login to avoid further SQL injection
- //Memory will be freed on AuthSocket object destruction
- _safelogin = _login;
- LoginDatabase.escape_string(_safelogin);
-
_build = ch->build;
pkt << (uint8) AUTH_LOGON_CHALLENGE;
pkt << (uint8) 0x00;
///- Verify that this IP is not in the ip_banned table
- // No SQL injection possible (paste the IP address as passed by the socket)
- LoginDatabase.Execute("DELETE FROM ip_banned WHERE unbandate<=UNIX_TIMESTAMP() AND unbandate<>bandate");
-
- std::string address(socket().get_remote_address().c_str());
- LoginDatabase.escape_string(address);
- QueryResult result = LoginDatabase.PQuery("SELECT * FROM ip_banned WHERE ip = '%s'",address.c_str());
+ LoginDatabase.Execute(
+ LoginDatabase.GetPreparedStatement(LOGIN_SET_EXPIREDIPBANS)
+ );
+
+ const std::string& ip_address = socket().get_remote_address();
+ PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_IPBANNED);
+ stmt->setString(0, ip_address);
+ PreparedQueryResult result = LoginDatabase.Query(stmt);
if (result)
{
pkt << (uint8)WOW_FAIL_BANNED;
- sLog.outBasic("[AuthChallenge] Banned ip %s tries to login!", address.c_str ());
+ sLog.outBasic("[AuthChallenge] Banned ip %s tries to login!", ip_address.c_str());
}
else
{
///- Get the account details from the account table
- // No SQL injection (escaped user name)
- PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_LOGONCHALLENGE);
- stmt->setString(0, _safelogin.c_str());
+ // No SQL injection (prepared statement)
+ stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_LOGONCHALLENGE);
+ stmt->setString(0, _login);
PreparedQueryResult res2 = LoginDatabase.Query(stmt);
if (res2)
@@ -391,8 +388,8 @@ bool AuthSocket::_HandleLogonChallenge()
if (res2->GetUInt8(2) == 1) // if ip is locked
{
sLog.outStaticDebug("[AuthChallenge] Account '%s' is locked to IP - '%s'", _login.c_str(), res2->GetString(3));
- sLog.outStaticDebug("[AuthChallenge] Player address is '%s'", socket().get_remote_address().c_str());
- if (strcmp(res2->GetString(3).c_str(),socket().get_remote_address().c_str()))
+ sLog.outStaticDebug("[AuthChallenge] Player address is '%s'", ip_address.c_str());
+ if (strcmp(res2->GetString(3).c_str(), ip_address.c_str()))
{
sLog.outStaticDebug("[AuthChallenge] Account IP differs");
pkt << (uint8) WOW_FAIL_SUSPENDED;
@@ -407,20 +404,25 @@ bool AuthSocket::_HandleLogonChallenge()
if (!locked)
{
//set expired bans to inactive
- LoginDatabase.Execute("UPDATE account_banned SET active = 0 WHERE unbandate<=UNIX_TIMESTAMP() AND unbandate<>bandate");
+ LoginDatabase.Execute(
+ LoginDatabase.GetPreparedStatement(LOGIN_SET_EXPIREDACCBANS)
+ );
+
///- If the account is banned, reject the logon attempt
- QueryResult banresult = LoginDatabase.PQuery("SELECT bandate,unbandate FROM account_banned WHERE id = %u AND active = 1", res2->GetUInt32(1));
+ stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_ACCBANNED);
+ stmt->setUInt32(0, res2->GetUInt32(1));
+ PreparedQueryResult banresult = LoginDatabase.Query(stmt);
if (banresult)
{
- if ((*banresult)[0].GetUInt64() == (*banresult)[1].GetUInt64())
+ if (banresult->GetUInt64(0) == banresult->GetUInt64(1))
{
pkt << (uint8) WOW_FAIL_BANNED;
- sLog.outBasic("[AuthChallenge] Banned account %s tries to login!",_login.c_str ());
+ sLog.outBasic("[AuthChallenge] Banned account %s tries to login!", _login.c_str());
}
else
{
pkt << (uint8) WOW_FAIL_SUSPENDED;
- sLog.outBasic("[AuthChallenge] Temporarily banned account %s tries to login!",_login.c_str ());
+ sLog.outBasic("[AuthChallenge] Temporarily banned account %s tries to login!", _login.c_str());
}
}
else
@@ -662,12 +664,16 @@ bool AuthSocket::_HandleLogonProof()
if (MaxWrongPassCount > 0)
{
//Increment number of failed logins by one and if it reaches the limit temporarily ban that account or IP
- LoginDatabase.PExecute("UPDATE account SET failed_logins = failed_logins + 1 WHERE username = '%s'",_safelogin.c_str());
+ PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SET_FAILEDLOGINS);
+ stmt->setString(0, _login);
+ LoginDatabase.Execute(stmt);
+
+ stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_FAILEDLOGINS);
+ stmt->setString(0, _login);
- if (QueryResult loginfail = LoginDatabase.PQuery("SELECT id, failed_logins FROM account WHERE username = '%s'", _safelogin.c_str()))
+ if (PreparedQueryResult loginfail = LoginDatabase.Query(stmt))
{
- Field* fields = loginfail->Fetch();
- uint32 failed_logins = fields[1].GetUInt32();
+ uint32 failed_logins = loginfail->GetUInt32(1);
if (failed_logins >= MaxWrongPassCount)
{
@@ -676,20 +682,24 @@ bool AuthSocket::_HandleLogonProof()
if (WrongPassBanType)
{
- uint32 acc_id = fields[0].GetUInt32();
- LoginDatabase.PExecute("INSERT INTO account_banned VALUES ('%u',UNIX_TIMESTAMP(),UNIX_TIMESTAMP()+'%u','Trinity realmd','Failed login autoban',1)",
- acc_id, WrongPassBanTime);
+ uint32 acc_id = loginfail->GetUInt32(0);
+ stmt = LoginDatabase.GetPreparedStatement(LOGIN_SET_ACCAUTOBANNED);
+ stmt->setUInt32(0, acc_id);
+ stmt->setUInt32(1, WrongPassBanTime);
+ LoginDatabase.Execute(stmt);
+
sLog.outBasic("[AuthChallenge] account %s got banned for '%u' seconds because it failed to authenticate '%u' times",
_login.c_str(), WrongPassBanTime, failed_logins);
}
else
{
- std::string current_ip(socket().get_remote_address().c_str());
- LoginDatabase.escape_string(current_ip);
- LoginDatabase.PExecute("INSERT INTO ip_banned VALUES ('%s',UNIX_TIMESTAMP(),UNIX_TIMESTAMP()+'%u','Trinity realmd','Failed login autoban')",
- current_ip.c_str(), WrongPassBanTime);
+ stmt = LoginDatabase.GetPreparedStatement(LOGIN_SET_IPAUTOBANNED);
+ stmt->setString(0, socket().get_remote_address());
+ stmt->setUInt32(1, WrongPassBanTime);
+ LoginDatabase.Execute(stmt);
+
sLog.outBasic("[AuthChallenge] IP %s got banned for '%u' seconds because account %s failed to authenticate '%u' times",
- current_ip.c_str(), WrongPassBanTime, _login.c_str(), failed_logins);
+ socket().get_remote_address().c_str(), WrongPassBanTime, _login.c_str(), failed_logins);
}
}
}
@@ -730,9 +740,10 @@ bool AuthSocket::_HandleReconnectChallenge()
sLog.outStaticDebug("[ReconnectChallenge] name(%d): '%s'", ch->I_len, ch->I);
_login = (const char*)ch->I;
- _safelogin = _login;
- QueryResult result = LoginDatabase.PQuery ("SELECT sessionkey FROM account WHERE username = '%s'", _safelogin.c_str ());
+ PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_SESSIONKEY);
+ stmt->setString(0, _login);
+ PreparedQueryResult result = LoginDatabase.Query(stmt);
// Stop if the account is not found
if (!result)
@@ -742,8 +753,7 @@ bool AuthSocket::_HandleReconnectChallenge()
return false;
}
- Field* fields = result->Fetch ();
- K.SetHexStr (fields[0].GetString ());
+ K.SetHexStr (result->GetString(0).c_str());
///- Sending response
ByteBuffer pkt;
@@ -809,17 +819,19 @@ bool AuthSocket::_HandleRealmList()
socket().recv_skip(5);
///- Get the user id (else close the connection)
- // No SQL injection (escaped user name)
+ // No SQL injection (prepared statement)
- QueryResult result = LoginDatabase.PQuery("SELECT id FROM account WHERE username = '%s'",_safelogin.c_str());
+ PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_ACCIDBYNAME);
+ stmt->setString(0, _login);
+ PreparedQueryResult result = LoginDatabase.Query(stmt);
if (!result)
{
- sLog.outError("[ERROR] user %s tried to login and we cannot find him in the database.",_login.c_str());
+ sLog.outError("[ERROR] user %s tried to login and we cannot find him in the database.", _login.c_str());
socket().shutdown();
return false;
}
- uint32 id = (*result)[0].GetUInt32();
+ uint32 id = result->GetUInt32(0);
///- Update realm list if need
sRealmList->UpdateIfNeed();
@@ -845,12 +857,12 @@ bool AuthSocket::_HandleRealmList()
uint8 AmountOfCharacters;
// No SQL injection. id of realm is controlled by the database.
- result = LoginDatabase.PQuery("SELECT numchars FROM realmcharacters WHERE realmid = '%d' AND acctid='%u'",i->second.m_ID,id);
+ stmt = LoginDatabase.GetPreparedStatement(LOGIN_GET_NUMCHARSONREALM);
+ stmt->setUInt32(0, i->second.m_ID);
+ stmt->setUInt32(1, id);
+ result = LoginDatabase.Query(stmt);
if (result)
- {
- Field *fields = result->Fetch();
- AmountOfCharacters = fields[0].GetUInt8();
- }
+ AmountOfCharacters = result->GetUInt8(0);
else
AmountOfCharacters = 0;
diff --git a/src/server/authserver/Server/RealmSocket.cpp b/src/server/authserver/Server/RealmSocket.cpp
index fdb8b8b54f4..0ba3343d135 100644
--- a/src/server/authserver/Server/RealmSocket.cpp
+++ b/src/server/authserver/Server/RealmSocket.cpp
@@ -107,7 +107,7 @@ int RealmSocket::close(int)
return 0;
}
-const ACE_CString& RealmSocket::get_remote_address(void) const
+const std::string& RealmSocket::get_remote_address(void) const
{
return remote_address_;
}
diff --git a/src/server/authserver/Server/RealmSocket.h b/src/server/authserver/Server/RealmSocket.h
index 8749fba9def..1eff8c59ae3 100644
--- a/src/server/authserver/Server/RealmSocket.h
+++ b/src/server/authserver/Server/RealmSocket.h
@@ -58,7 +58,7 @@ class RealmSocket : public ACE_Svc_Handler<ACE_SOCK_STREAM, ACE_NULL_SYNCH>
bool send(const char *buf, size_t len);
- const ACE_CString& get_remote_address(void) const;
+ const std::string& get_remote_address(void) const;
virtual int open(void *);
@@ -78,7 +78,7 @@ class RealmSocket : public ACE_Svc_Handler<ACE_SOCK_STREAM, ACE_NULL_SYNCH>
private:
ACE_Message_Block input_buffer_;
Session* session_;
- ACE_CString remote_address_;
+ std::string remote_address_;
};
#endif /* __REALMSOCKET_H__ */
diff --git a/src/server/shared/Database/Implementation/LoginDatabase.cpp b/src/server/shared/Database/Implementation/LoginDatabase.cpp
index 7cf7e30b0f6..9032cb639d3 100644
--- a/src/server/shared/Database/Implementation/LoginDatabase.cpp
+++ b/src/server/shared/Database/Implementation/LoginDatabase.cpp
@@ -31,9 +31,21 @@ bool LoginDatabaseConnection::Open(const std::string& infoString)
##################################
*/
PrepareStatement(LOGIN_GET_REALMLIST, "SELECT id, name, address, port, icon, color, timezone, allowedSecurityLevel, population, gamebuild FROM realmlist WHERE color <> 3 ORDER BY name");
+ PrepareStatement(LOGIN_SET_EXPIREDIPBANS, "DELETE FROM ip_banned WHERE unbandate<=UNIX_TIMESTAMP() AND unbandate<>bandate");
+ PrepareStatement(LOGIN_SET_EXPIREDACCBANS, "UPDATE account_banned SET active = 0 WHERE unbandate<=UNIX_TIMESTAMP() AND unbandate<>bandate");
+ PrepareStatement(LOGIN_GET_IPBANNED, "SELECT * FROM ip_banned WHERE ip = ?");
+ PrepareStatement(LOGIN_SET_IPAUTOBANNED, "INSERT INTO ip_banned VALUES (?, UNIX_TIMESTAMP(), UNIX_TIMESTAMP()+?,'Trinity realmd', 'Failed login autoban')");
+ PrepareStatement(LOGIN_GET_ACCBANNED, "SELECT bandate,unbandate FROM account_banned WHERE id = ? AND active = 1");
+ PrepareStatement(LOGIN_SET_ACCAUTOBANNED, "INSERT INTO account_banned VALUES (?, UNIX_TIMESTAMP(), UNIX_TIMESTAMP()+?, 'Trinity realmd', 'Failed login autoban', 1)");
+ PrepareStatement(LOGIN_GET_SESSIONKEY, "SELECT sessionkey FROM account WHERE username = ?");
PrepareStatement(LOGIN_SET_VS, "UPDATE account SET v = ?, s = ? WHERE username = ?");
PrepareStatement(LOGIN_SET_LOGONPROOF, "UPDATE account SET sessionkey = ?, last_ip = ?, last_login = NOW(), locale = ?, failed_logins = 0 WHERE username = ?");
PrepareStatement(LOGIN_GET_LOGONCHALLENGE, "SELECT a.sha_pass_hash,a.id,a.locked,a.last_ip,aa.gmlevel,a.v,a.s FROM account a LEFT JOIN account_access aa ON (a.id = aa.id) WHERE a.username = ?");
+ PrepareStatement(LOGIN_SET_FAILEDLOGINS, "UPDATE account SET failed_logins = failed_logins + 1 WHERE username = ?");
+ PrepareStatement(LOGIN_GET_FAILEDLOGINS, "SELECT id, failed_logins FROM account WHERE username = ?");
+ PrepareStatement(LOGIN_GET_ACCIDBYNAME, "SELECT id FROM account WHERE username = ?");
+ PrepareStatement(LOGIN_GET_NUMCHARSONREALM, "SELECT numchars FROM realmcharacters WHERE realmid = ? AND acctid= ?");
+
return true;
}
diff --git a/src/server/shared/Database/Implementation/LoginDatabase.h b/src/server/shared/Database/Implementation/LoginDatabase.h
index 2dde12014d0..4792858d50e 100644
--- a/src/server/shared/Database/Implementation/LoginDatabase.h
+++ b/src/server/shared/Database/Implementation/LoginDatabase.h
@@ -44,9 +44,22 @@ enum LoginDatabaseStatements
*/
LOGIN_GET_REALMLIST,
+ LOGIN_SET_EXPIREDIPBANS,
+ LOGIN_SET_EXPIREDACCBANS,
+ LOGIN_GET_IPBANNED,
+ LOGIN_SET_IPAUTOBANNED,
+ LOGIN_GET_ACCBANNED,
+ LOGIN_SET_ACCAUTOBANNED,
+ LOGIN_GET_SESSIONKEY,
LOGIN_SET_VS,
LOGIN_SET_LOGONPROOF,
LOGIN_GET_LOGONCHALLENGE,
+ LOGIN_SET_FAILEDLOGINS,
+ LOGIN_GET_FAILEDLOGINS,
+
+ LOGIN_GET_ACCIDBYNAME,
+ LOGIN_GET_NUMCHARSONREALM,
+
MAX_LOGINDATABASE_STATEMENTS,
};