diff options
author | Machiavelli <none@none> | 2010-09-24 22:16:21 +0200 |
---|---|---|
committer | Machiavelli <none@none> | 2010-09-24 22:16:21 +0200 |
commit | 3c6dc320308880bde4ef9eddd695db28a74aa0d9 (patch) | |
tree | f209e6c487e436fc1cd978487dddf3604ce2b594 /src/server/shared/Database | |
parent | b46b498141cc167163c6112e8e2bfa32fec2d7dc (diff) |
Core/DBLayer:
- Rewrite Field class to be able to store both binary prepared statement data and data from adhoc query resultsets
- Buffer the data of prepared statements using ResultSet and Field classes and let go of mysql c api structures after PreparedResultSet constructor. Fixes a race condition and thus a possible crash/data corruption (issue pointed out to Derex, basic suggestion by raczman)
- Conform PreparedResultSet and ResultSet to the same design standards, and using Field class as data buffer class for both
* NOTE: This means the fetching methods are uniform again, using ¨Field* fields = result->Fetch();¨ and access to elements trough fields[x].
* NOTE: for access to the correct row in prepared statements, ¨Field* fields = result->Fetch();¨ must ALWAYS be called inside the do { }while(result->NextRow()) loop.
* NOTE: This means that Field::GetString() returns std::string object and Field::GetCString() returns const char* pointer.
Still experimental and all that jazz, not recommended for production servers until feedback is given.
--HG--
branch : trunk
Diffstat (limited to 'src/server/shared/Database')
-rw-r--r-- | src/server/shared/Database/DatabaseWorkerPool.h | 1 | ||||
-rw-r--r-- | src/server/shared/Database/Field.cpp | 66 | ||||
-rw-r--r-- | src/server/shared/Database/Field.h | 320 | ||||
-rw-r--r-- | src/server/shared/Database/MySQLConnection.cpp | 70 | ||||
-rw-r--r-- | src/server/shared/Database/MySQLConnection.h | 1 | ||||
-rw-r--r-- | src/server/shared/Database/QueryHolder.cpp | 15 | ||||
-rw-r--r-- | src/server/shared/Database/QueryResult.cpp | 304 | ||||
-rwxr-xr-x | src/server/shared/Database/QueryResult.h | 197 | ||||
-rw-r--r-- | src/server/shared/Database/SQLStorageImpl.h | 2 |
9 files changed, 517 insertions, 459 deletions
diff --git a/src/server/shared/Database/DatabaseWorkerPool.h b/src/server/shared/Database/DatabaseWorkerPool.h index 2d55a6f1b6b..b2afdbd3bad 100644 --- a/src/server/shared/Database/DatabaseWorkerPool.h +++ b/src/server/shared/Database/DatabaseWorkerPool.h @@ -301,7 +301,6 @@ class DatabaseWorkerPool if (!ret || !ret->GetRowCount()) return PreparedQueryResult(NULL); - ret->NextRow(); return PreparedQueryResult(ret); } diff --git a/src/server/shared/Database/Field.cpp b/src/server/shared/Database/Field.cpp index ac83bf055fb..9ef23f3ad17 100644 --- a/src/server/shared/Database/Field.cpp +++ b/src/server/shared/Database/Field.cpp @@ -1,6 +1,4 @@ /* - * Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/> - * * Copyright (C) 2008-2010 Trinity <http://www.trinitycore.org/> * * This program is free software; you can redistribute it and/or modify @@ -20,58 +18,42 @@ #include "Field.h" -Field::Field() : -mValue(NULL), mType(DB_TYPE_UNKNOWN) +Field::Field() { + data.value = NULL; + data.type = MYSQL_TYPE_NULL; + data.length = 0; } -Field::Field(Field &f) +Field::~Field() { - const char *value; - - value = f.GetString(); - - if (value) - { - mValue = new char[strlen(value) + 1]; - if (mValue) - strcpy(mValue, value); - } - else - mValue = NULL; - - mType = f.GetType(); + CleanUp(); } -Field::Field(const char *value, enum Field::DataTypes type) : -mType(type) +void Field::SetByteValue(void* newValue, const size_t newSize, enum_field_types newType, uint32 length) { - if (value) + // This value stores raw bytes that have to be explicitly casted later + if (newValue) { - mValue = new char[strlen(value) + 1]; - if (mValue) - strcpy(mValue, value); + data.value = new char [newSize]; + memcpy(data.value, newValue, newSize); + data.length = length; } - else - mValue = NULL; + data.type = newType; + data.raw = true; } -Field::~Field() -{ - if (mValue) - delete [] mValue; -} - -void Field::SetValue(const char *value) +void Field::SetStructuredValue(char* newValue, enum_field_types newType, const size_t newSize) { - if (mValue) - delete [] mValue; - - if (value) + // This value stores somewhat structured data that needs function style casting + if (newValue) { - mValue = new char[strlen(value) + 1]; - strcpy(mValue, value); + size_t size = strlen(newValue); + data.value = new char [size+1]; + strcpy(data.value, newValue); + data.length = size; } - else - mValue = NULL; + + data.type = newType; + data.raw = false; } diff --git a/src/server/shared/Database/Field.h b/src/server/shared/Database/Field.h index 2885d9eff8b..0870f1c8562 100644 --- a/src/server/shared/Database/Field.h +++ b/src/server/shared/Database/Field.h @@ -1,6 +1,4 @@ /* - * Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/> - * * Copyright (C) 2008-2010 Trinity <http://www.trinitycore.org/> * * This program is free software; you can redistribute it and/or modify @@ -18,74 +16,308 @@ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ -#if !defined(FIELD_H) -#define FIELD_H -#include <iostream> +#ifndef _FIELD_H +#define _FIELD_H + #include "Common.h" +#include "Log.h" + +#include <mysql.h> class Field { + friend class ResultSet; + friend class PreparedResultSet; + public: + + bool GetBool() const // Wrapper, actually gets integer + { + return GetUInt8() == 1 ? true : false; + } - enum DataTypes + uint8 GetUInt8() const { - DB_TYPE_UNKNOWN = 0x00, - DB_TYPE_STRING = 0x01, - DB_TYPE_INTEGER = 0x02, - DB_TYPE_FLOAT = 0x03, - DB_TYPE_BOOL = 0x04 - }; + if (!data.value) + return 0; - Field(); - Field(Field &f); - Field(const char *value, enum DataTypes type); + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetUInt8() on non-numeric field."); + return 0; + } + #endif + if (data.raw) + return *reinterpret_cast<uint8*>(data.value); + return static_cast<uint8>(atol((char*)data.value)); + } - ~Field(); + int8 GetInt8() const + { + if (!data.value) + return 0; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GeInt8() on non-numeric field."); + return 0; + } + #endif + if (data.raw) + return *reinterpret_cast<int8*>(data.value); + return static_cast<int8>(atol((char*)data.value)); + } + + uint16 GetUInt16() const + { + if (!data.value) + return 0; - enum DataTypes GetType() const { return mType; } + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetUInt16() on non-numeric field."); + return 0; + } + #endif + if (data.raw) + return *reinterpret_cast<uint16*>(data.value); + return static_cast<uint16>(atol((char*)data.value)); + } - const char *GetString() const { return mValue; } - std::string GetCppString() const + int16 GetInt16() const { - return mValue ? mValue : ""; // std::string s = 0 have undefine result in C++ + if (!data.value) + return 0; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetInt16() on non-numeric field."); + return 0; + } + #endif + if (data.raw) + return *reinterpret_cast<int16*>(data.value); + return static_cast<int16>(atol((char*)data.value)); } - float GetFloat() const { return mValue ? static_cast<float>(atof(mValue)) : 0.0f; } - bool GetBool() const { return mValue ? atoi(mValue) > 0 : false; } - int32 GetInt32() const { return mValue ? static_cast<int32>(atol(mValue)) : int32(0); } - uint8 GetUInt8() const { return mValue ? static_cast<uint8>(atol(mValue)) : uint8(0); } - uint16 GetUInt16() const { return mValue ? static_cast<uint16>(atol(mValue)) : uint16(0); } - int16 GetInt16() const { return mValue ? static_cast<int16>(atol(mValue)) : int16(0); } - uint32 GetUInt32() const { return mValue ? static_cast<uint32>(atol(mValue)) : uint32(0); } - uint64 GetUInt64() const + + uint32 GetUInt32() const { - if(mValue) + if (!data.value) + return 0; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) { - uint64 value; - sscanf(mValue,UI64FMTD,&value); - return value; + sLog.outSQLDriver("Error: GetUInt32() on non-numeric field."); + return 0; } - else + #endif + if (data.raw) + return *reinterpret_cast<uint32*>(data.value); + return static_cast<uint32>(atol((char*)data.value)); + } + + int32 GetInt32() const + { + if (!data.value) + return 0; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetInt32() on non-numeric field."); return 0; + } + #endif + if (data.raw) + return *reinterpret_cast<int32*>(data.value); + return static_cast<int32>(atol((char*)data.value)); } - uint64 GetInt64() const + + uint64 GetUInt64() const { - if(mValue) + if (!data.value) + return 0; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) { - int64 value; - sscanf(mValue,SI64FMTD,&value); - return value; + sLog.outSQLDriver("Error: GetUInt64() on non-numeric field."); + return 0; } - else + #endif + if (data.raw) + return *reinterpret_cast<uint64*>(data.value); + return static_cast<uint64>(atol((char*)data.value)); + } + + int64 GetInt64() const + { + if (!data.value) + return 0; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetInt64() on non-numeric field."); return 0; + } + #endif + if (data.raw) + return *reinterpret_cast<int64*>(data.value); + return static_cast<int64>(atol((char*)data.value)); + } + + float GetFloat() const + { + if (!data.value) + return 0.0f; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetFloat() on non-numeric field."); + return 0.0f; + } + #endif + if (data.raw) + return *reinterpret_cast<float*>(data.value); + return static_cast<float>(atof((char*)data.value)); + } + + double GetDouble() const + { + if (!data.value) + return 0.0f; + + #ifdef TRINITY_DEBUG + if (!IsNumeric()) + { + sLog.outSQLDriver("Error: GetDouble() on non-numeric field."); + return 0.0f; + } + #endif + if (data.raw) + return *reinterpret_cast<double*>(data.value); + return static_cast<double>(atof((char*)data.value)); + } + + const char* GetCString() const + { + if (!data.value) + return NULL; + + #ifdef TRINITY_DEBUG + if (IsNumeric()) + { + sLog.outSQLDriver("Error: GetCString() on numeric field."); + return NULL; + } + #endif + return static_cast<const char*>(data.value); + } + + std::string GetString() const + { + if (!data.value) + return ""; + + if (data.raw) + { + const char* string = GetCString(); + if (!string) + string = ""; + return std::string(string, data.length); + } + return std::string((char*)data.value); + } + + protected: + Field(); + ~Field(); + + struct + { + enum_field_types type; // Field type + void* value; // Actual data in memory + bool raw; // Raw bytes? (Prepared statement or adhoc) + uint32 length; // Length (prepared strings only) + } data; + + void SetByteValue(void* newValue, const size_t newSize, enum_field_types newType, uint32 length); + void SetStructuredValue(char* newValue, enum_field_types newType, const size_t newSize); + + void CleanUp() + { + delete[] (data.value); + data.value = NULL; } - void SetType(enum DataTypes type) { mType = type; } + static size_t SizeForType(MYSQL_FIELD* field) + { + switch (field->type) + { + case MYSQL_TYPE_NULL: + return 0; + case MYSQL_TYPE_TINY: + return 1; + case MYSQL_TYPE_YEAR: + case MYSQL_TYPE_SHORT: + return 2; + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_FLOAT: + return 4; + case MYSQL_TYPE_DOUBLE: + case MYSQL_TYPE_LONGLONG: + case MYSQL_TYPE_BIT: + return 8; + + case MYSQL_TYPE_TIMESTAMP: + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_TIME: + case MYSQL_TYPE_DATETIME: + return sizeof(MYSQL_TIME); + + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + return field->max_length + 1; - void SetValue(const char *value); + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_NEWDECIMAL: + return 64; - private: - char *mValue; - enum DataTypes mType; + case MYSQL_TYPE_GEOMETRY: + /* + Following types are not sent over the wire: + MYSQL_TYPE_ENUM: + MYSQL_TYPE_SET: + */ + default: + sLog.outSQLDriver("SQL::SizeForType(): invalid field type %u", uint32(field->type)); + return 0; + } + } + + bool IsNumeric() const + { + return (data.type == MYSQL_TYPE_TINY || + data.type == MYSQL_TYPE_SHORT || + data.type == MYSQL_TYPE_INT24 || + data.type == MYSQL_TYPE_LONG || + data.type == MYSQL_TYPE_FLOAT || + data.type == MYSQL_TYPE_DOUBLE || + data.type == MYSQL_TYPE_LONGLONG ); + } }; + #endif diff --git a/src/server/shared/Database/MySQLConnection.cpp b/src/server/shared/Database/MySQLConnection.cpp index 54bf97ce601..f3fa2f352da 100644 --- a/src/server/shared/Database/MySQLConnection.cpp +++ b/src/server/shared/Database/MySQLConnection.cpp @@ -222,14 +222,63 @@ bool MySQLConnection::Execute(PreparedStatement* stmt) m_mStmt->ClearParameters(); return false; } - else + + #ifdef SQLQUERY_LOG + sLog.outSQLDriver("[%u ms] Prepared SQL: %u", getMSTimeDiff(_s, getMSTime()), index); + #endif + m_mStmt->ClearParameters(); + return true; + } +} + +bool MySQLConnection::_Query(PreparedStatement* stmt, MYSQL_RES **pResult, MYSQL_FIELD **pFields, uint64* pRowCount, uint32* pFieldCount) +{ + if (!m_Mysql) + return false; + + uint32 index = stmt->m_index; + { + // guarded block for thread-safe mySQL request + ACE_Guard<ACE_Thread_Mutex> query_connection_guard(m_Mutex); + + MySQLPreparedStatement* m_mStmt = GetPreparedStatement(index); + ASSERT(m_mStmt); // Can only be null if preparation failed, server side error or bad query + m_mStmt->m_stmt = stmt; // Cross reference them for debug output + stmt->m_stmt = m_mStmt; // TODO: Cleaner way + + stmt->BindParameters(); + + MYSQL_STMT* msql_STMT = m_mStmt->GetSTMT(); + MYSQL_BIND* msql_BIND = m_mStmt->GetBind(); + + #ifdef SQLQUERY_LOG + uint32 _s = getMSTime(); + #endif + if (mysql_stmt_bind_param(msql_STMT, msql_BIND)) { - #ifdef SQLQUERY_LOG - sLog.outSQLDriver("[%u ms] Prepared SQL: %u", getMSTimeDiff(_s, getMSTime()), index); - #endif + sLog.outSQLDriver("[ERROR]: PreparedStatement (id: %u) error binding params: %s", index, mysql_stmt_error(msql_STMT)); + m_mStmt->ClearParameters(); + return false; + } + + if (mysql_stmt_execute(msql_STMT)) + { + sLog.outSQLDriver("[ERROR]: PreparedStatement (id: %u) error executing: %s", index, mysql_stmt_error(msql_STMT)); m_mStmt->ClearParameters(); - return true; + return false; } + + #ifdef SQLQUERY_LOG + sLog.outSQLDriver("[%u ms] Prepared SQL: %u", getMSTimeDiff(_s, getMSTime()), index); + #endif + m_mStmt->ClearParameters(); + + *pResult = mysql_stmt_result_metadata(msql_STMT); + *pRowCount = /*mysql_affected_rows(m_Mysql); //* or*/ mysql_stmt_num_rows(msql_STMT); + *pFieldCount = mysql_stmt_field_count(msql_STMT); + + return true; + } } @@ -336,10 +385,17 @@ void MySQLConnection::PrepareStatement(uint32 index, const char* sql) PreparedResultSet* MySQLConnection::Query(PreparedStatement* stmt) { - this->Execute(stmt); + MYSQL_RES *result = NULL; + MYSQL_FIELD *fields = NULL; + uint64 rowCount = 0; + uint32 fieldCount = 0; + + if (!_Query(stmt, &result, &fields, &rowCount, &fieldCount)) + return NULL; + if (mysql_more_results(m_Mysql)) { mysql_next_result(m_Mysql); } - return new PreparedResultSet(stmt->m_stmt->GetSTMT()); + return new PreparedResultSet(stmt->m_stmt->GetSTMT(), result, fields, rowCount, fieldCount); } diff --git a/src/server/shared/Database/MySQLConnection.h b/src/server/shared/Database/MySQLConnection.h index b707f8a4675..09c30e7073e 100644 --- a/src/server/shared/Database/MySQLConnection.h +++ b/src/server/shared/Database/MySQLConnection.h @@ -43,6 +43,7 @@ class MySQLConnection ResultSet* Query(const char* sql); PreparedResultSet* Query(PreparedStatement* stmt); bool _Query(const char *sql, MYSQL_RES **pResult, MYSQL_FIELD **pFields, uint64* pRowCount, uint32* pFieldCount); + bool _Query(PreparedStatement* stmt, MYSQL_RES **pResult, MYSQL_FIELD **pFields, uint64* pRowCount, uint32* pFieldCount); void BeginTransaction(); void RollbackTransaction(); diff --git a/src/server/shared/Database/QueryHolder.cpp b/src/server/shared/Database/QueryHolder.cpp index 26ccce3853c..6c152f478f8 100644 --- a/src/server/shared/Database/QueryHolder.cpp +++ b/src/server/shared/Database/QueryHolder.cpp @@ -89,13 +89,6 @@ QueryResult SQLQueryHolder::GetResult(size_t index) // Don't call to this function if the index is of an ad-hoc statement if (index < m_queries.size()) { - /// the query strings are freed on the first GetResult or in the destructor - if (SQLElementData* data = &m_queries[index].first) - { - free((void*)(const_cast<char*>(data->element.query))); - data->element.query = NULL; - } - ResultSet* result = m_queries[index].second.qresult; if (!result || !result->GetRowCount()) return QueryResult(NULL); @@ -112,18 +105,10 @@ PreparedQueryResult SQLQueryHolder::GetPreparedResult(size_t index) // Don't call to this function if the index is of a prepared statement if (index < m_queries.size()) { - /// the query strings are freed on the first GetResult or in the destructor - if (SQLElementData* data = &m_queries[index].first) - { - delete data->element.stmt; - data->element.stmt = NULL; - } - PreparedResultSet* result = m_queries[index].second.presult; if (!result || !result->GetRowCount()) return PreparedQueryResult(NULL); - result->NextRow(); return PreparedQueryResult(result); } else diff --git a/src/server/shared/Database/QueryResult.cpp b/src/server/shared/Database/QueryResult.cpp index 283734b80b7..b6a4e8477c9 100644 --- a/src/server/shared/Database/QueryResult.cpp +++ b/src/server/shared/Database/QueryResult.cpp @@ -21,106 +21,29 @@ #include "DatabaseEnv.h" #include "Log.h" -ResultSet::ResultSet(MYSQL_RES *result, MYSQL_FIELD *fields, uint64 rowCount, uint32 fieldCount) -: mFieldCount(fieldCount) -, mRowCount(rowCount) -, mResult(result) +ResultSet::ResultSet(MYSQL_RES *result, MYSQL_FIELD *fields, uint64 rowCount, uint32 fieldCount) : +m_result(result), +m_fields(fields), +m_rowCount(rowCount), +m_fieldCount(fieldCount) +{ + m_currentRow = new Field[m_fieldCount]; + ASSERT(m_currentRow); +} + +PreparedResultSet::PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES *result, MYSQL_FIELD *fields, uint64 rowCount, uint32 fieldCount) : +m_rBind(NULL), +m_stmt(stmt), +m_res(result), +m_isNull(NULL), +m_length(NULL), +m_rowCount(rowCount), +m_fieldCount(fieldCount), +m_rowPosition(0) { - mCurrentRow = new Field[mFieldCount]; - ASSERT(mCurrentRow); - - for (uint32 i = 0; i < mFieldCount; i++) - mCurrentRow[i].SetType(ConvertNativeType(fields[i].type)); -} - -ResultSet::~ResultSet() -{ - EndQuery(); -} - -bool ResultSet::NextRow() -{ - MYSQL_ROW row; - - if (!mResult) - return false; - - row = mysql_fetch_row(mResult); - if (!row) - { - EndQuery(); - return false; - } - - for (uint32 i = 0; i < mFieldCount; i++) - mCurrentRow[i].SetValue(row[i]); - - return true; -} - -void ResultSet::EndQuery() -{ - if (mCurrentRow) - { - delete [] mCurrentRow; - mCurrentRow = 0; - } - - if (mResult) - { - mysql_free_result(mResult); - mResult = 0; - } -} - -enum Field::DataTypes ResultSet::ConvertNativeType(enum_field_types mysqlType) const -{ - switch (mysqlType) - { - case FIELD_TYPE_TIMESTAMP: - case FIELD_TYPE_DATE: - case FIELD_TYPE_TIME: - case FIELD_TYPE_DATETIME: - case FIELD_TYPE_YEAR: - case FIELD_TYPE_STRING: - case FIELD_TYPE_VAR_STRING: - case FIELD_TYPE_BLOB: - case FIELD_TYPE_SET: - case FIELD_TYPE_NULL: - return Field::DB_TYPE_STRING; - case FIELD_TYPE_TINY: - - case FIELD_TYPE_SHORT: - case FIELD_TYPE_LONG: - case FIELD_TYPE_INT24: - case FIELD_TYPE_LONGLONG: - case FIELD_TYPE_ENUM: - return Field::DB_TYPE_INTEGER; - case FIELD_TYPE_DECIMAL: - case FIELD_TYPE_FLOAT: - case FIELD_TYPE_DOUBLE: - return Field::DB_TYPE_FLOAT; - default: - return Field::DB_TYPE_UNKNOWN; - } -} - -void ResultBind::BindResult(uint64& num_rows) -{ - FreeBindBuffer(); - - m_res = mysql_stmt_result_metadata(m_stmt); if (!m_res) return; - m_fieldCount = mysql_stmt_field_count(m_stmt); - - if (m_stmt->bind_result_done) - { - delete[] m_stmt->bind->length; - delete[] m_stmt->bind->is_null; - } - m_rBind = new MYSQL_BIND[m_fieldCount]; m_isNull = new my_bool[m_fieldCount]; m_length = new unsigned long[m_fieldCount]; @@ -141,7 +64,7 @@ void ResultBind::BindResult(uint64& num_rows) MYSQL_FIELD* field; while ((field = mysql_fetch_field(m_res))) { - size_t size = SizeForType(field); + size_t size = Field::SizeForType(field); m_rBind[i].buffer_type = field->type; m_rBind[i].buffer = malloc(size); @@ -165,133 +88,136 @@ void ResultBind::BindResult(uint64& num_rows) return; } - num_rows = mysql_stmt_num_rows(m_stmt); -} - -void ResultBind::FreeBindBuffer() -{ - for (uint32 i = 0; i < m_fieldCount; ++i) - free (m_rBind[i].buffer); -} - -void ResultBind::CleanUp() -{ - if (m_res) - mysql_free_result(m_res); + m_rowCount = mysql_stmt_num_rows(m_stmt); - FreeBindBuffer(); - mysql_stmt_free_result(m_stmt); + m_rows.resize(m_rowCount); + while (_NextRow()) + { + m_rows[m_rowPosition] = new Field[m_fieldCount]; + for (uint64 fIndex = 0; fIndex < m_fieldCount; ++fIndex) + { + if (!*m_rBind[fIndex].is_null) + m_rows[m_rowPosition][fIndex].SetByteValue( m_rBind[fIndex].buffer, + m_rBind[fIndex].buffer_length, + m_rBind[fIndex].buffer_type, + *m_rBind[fIndex].length ); + else + switch (m_rBind[fIndex].buffer_type) + { + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VAR_STRING: + m_rows[m_rowPosition][fIndex].SetByteValue( "", + m_rBind[fIndex].buffer_length, + m_rBind[fIndex].buffer_type, + *m_rBind[fIndex].length ); + break; + default: + m_rows[m_rowPosition][fIndex].SetByteValue( 0, + m_rBind[fIndex].buffer_length, + m_rBind[fIndex].buffer_type, + *m_rBind[fIndex].length ); + } + } + m_rowPosition++; + } + m_rowPosition = 0; - delete[] m_rBind; + /// All data is buffered, let go of mysql c api structures + CleanUp(); } -bool PreparedResultSet::GetBool(uint32 index) +ResultSet::~ResultSet() { - // TODO: Perhaps start storing data in genuine bool formats in tables - return GetUInt8(index) == 1 ? true : false; + CleanUp(); } -uint8 PreparedResultSet::GetUInt8(uint32 index) +PreparedResultSet::~PreparedResultSet() { - if (!CheckFieldIndex(index)) - return 0; - - return *reinterpret_cast<uint8*>(rbind->m_rBind[index].buffer); + for (uint64 i = 0; i < m_rowCount; ++i) + delete[] m_rows[i]; } -int8 PreparedResultSet::GetInt8(uint32 index) +bool ResultSet::NextRow() { - if (!CheckFieldIndex(index)) - return 0; - - return *reinterpret_cast<int8*>(rbind->m_rBind[index].buffer); -} + MYSQL_ROW row; -uint16 PreparedResultSet::GetUInt16(uint32 index) -{ - if (!CheckFieldIndex(index)) - return 0; + if (!m_result) + return false; - return *reinterpret_cast<uint16*>(rbind->m_rBind[index].buffer); -} + row = mysql_fetch_row(m_result); + if (!row) + { + CleanUp(); + return false; + } -int16 PreparedResultSet::GetInt16(uint32 index) -{ - if (!CheckFieldIndex(index)) - return 0; + for (uint32 i = 0; i < m_fieldCount; i++) + m_currentRow[i].SetStructuredValue(row[i], m_fields[i].type, Field::SizeForType(&m_fields[i])); - return *reinterpret_cast<int16*>(rbind->m_rBind[index].buffer); + return true; } -uint32 PreparedResultSet::GetUInt32(uint32 index) +bool PreparedResultSet::NextRow() { - if (!CheckFieldIndex(index)) - return 0; + /// Only updates the m_rowPosition so upper level code knows in which element + /// of the rows vector to look + if (++m_rowPosition >= m_rowCount) + return false; - return *reinterpret_cast<uint32*>(rbind->m_rBind[index].buffer); + return true; } -int32 PreparedResultSet::GetInt32(uint32 index) +bool PreparedResultSet::_NextRow() { - if (!CheckFieldIndex(index)) - return 0; - - return *reinterpret_cast<int32*>(rbind->m_rBind[index].buffer); -} + /// Only called in low-level code, namely the constructor + /// Will iterate over every row of data and buffer it + if (m_rowPosition >= m_rowCount) + return false; -float PreparedResultSet::GetFloat(uint32 index) -{ - if (!CheckFieldIndex(index)) - return 0; + int retval = mysql_stmt_fetch( m_stmt ); - return *reinterpret_cast<float*>(rbind->m_rBind[index].buffer); -} + if (!retval || retval == MYSQL_DATA_TRUNCATED) + retval = true; -uint64 PreparedResultSet::GetUInt64(uint32 index) -{ - if (!CheckFieldIndex(index)) - return 0; + if (retval == MYSQL_NO_DATA) + retval = false; - return *reinterpret_cast<uint64*>(rbind->m_rBind[index].buffer); + return retval; } -int64 PreparedResultSet::GetInt64(uint32 index) +void ResultSet::CleanUp() { - if (!CheckFieldIndex(index)) - return 0; + if (m_currentRow) + { + delete [] m_currentRow; + m_currentRow = NULL; + } - return *reinterpret_cast<int64*>(rbind->m_rBind[index].buffer); + if (m_result) + { + mysql_free_result(m_result); + m_result = NULL; + } } -std::string PreparedResultSet::GetString(uint32 index) +void PreparedResultSet::CleanUp() { - if (!CheckFieldIndex(index)) - return std::string(""); - - return std::string(static_cast<char const*>(rbind->m_rBind[index].buffer), *rbind->m_rBind[index].length); -} + /// More of the in our code allocated sources are deallocated by the poorly documented mysql c api + if (m_res) + mysql_free_result(m_res); -const char* PreparedResultSet::GetCString(uint32 index) -{ - if (!CheckFieldIndex(index)) - return '\0'; + FreeBindBuffer(); + mysql_stmt_free_result(m_stmt); - return static_cast<char const*>(rbind->m_rBind[index].buffer); + delete[] m_rBind; } -bool PreparedResultSet::NextRow() +void PreparedResultSet::FreeBindBuffer() { - if (row_position >= num_rows) - return false; - - int retval = mysql_stmt_fetch( rbind->m_stmt ); - - if (!retval || retval == MYSQL_DATA_TRUNCATED) - retval = true; - - if (retval == MYSQL_NO_DATA) - retval = false; - - ++row_position; - return retval; -}
\ No newline at end of file + for (uint32 i = 0; i < m_fieldCount; ++i) + free (m_rBind[i].buffer); +} diff --git a/src/server/shared/Database/QueryResult.h b/src/server/shared/Database/QueryResult.h index 22cd8bbf19e..aa088b5f121 100755 --- a/src/server/shared/Database/QueryResult.h +++ b/src/server/shared/Database/QueryResult.h @@ -39,193 +39,70 @@ class ResultSet ~ResultSet(); bool NextRow(); - - Field *Fetch() const { return mCurrentRow; } - - const Field & operator [] (int index) const { return mCurrentRow[index]; } - - uint32 GetFieldCount() const { return mFieldCount; } - uint64 GetRowCount() const { return mRowCount; } + uint64 GetRowCount() const { return m_rowCount; } + uint32 GetFieldCount() const { return m_fieldCount; } + + Field *Fetch() const { return m_currentRow; } + const Field & operator [] (uint32 index) const + { + ASSERT(index < m_rowCount); + return m_currentRow[index]; + } protected: - Field *mCurrentRow; - uint32 mFieldCount; - uint64 mRowCount; + Field *m_currentRow; + uint64 m_rowCount; + uint32 m_fieldCount; private: - enum Field::DataTypes ConvertNativeType(enum_field_types mysqlType) const; - void EndQuery(); - MYSQL_RES *mResult; - + void CleanUp(); + MYSQL_RES *m_result; + MYSQL_FIELD *m_fields; }; typedef ACE_Refcounted_Auto_Ptr<ResultSet, ACE_Null_Mutex> QueryResult; -typedef std::vector<std::string> QueryFieldNames; - -class QueryNamedResult +class PreparedResultSet { public: - explicit QueryNamedResult(ResultSet* query, QueryFieldNames const& names) : mQuery(query), mFieldNames(names) {} - ~QueryNamedResult() { delete mQuery; } - - // compatible interface with ResultSet - bool NextRow() { return mQuery->NextRow(); } - Field *Fetch() const { return mQuery->Fetch(); } - uint32 GetFieldCount() const { return mQuery->GetFieldCount(); } - uint64 GetRowCount() const { return mQuery->GetRowCount(); } - Field const& operator[] (int index) const { return (*mQuery)[index]; } + PreparedResultSet(MYSQL_STMT* stmt, MYSQL_RES *result, MYSQL_FIELD *fields, uint64 rowCount, uint32 fieldCount); + ~PreparedResultSet(); - // named access - Field const& operator[] (const std::string &name) const { return mQuery->Fetch()[GetField_idx(name)]; } - QueryFieldNames const& GetFieldNames() const { return mFieldNames; } + bool NextRow(); + uint64 GetRowCount() const { return m_rowCount; } + uint32 GetFieldCount() const { return m_fieldCount; } - uint32 GetField_idx(const std::string &name) const + Field* Fetch() const { - for (size_t idx = 0; idx < mFieldNames.size(); ++idx) - { - if(mFieldNames[idx] == name) - return idx; - } - ASSERT(false && "unknown field name"); - return uint32(-1); + ASSERT(m_rowPosition < m_rowCount); + return m_rows[m_rowPosition]; } - protected: - ResultSet *mQuery; - QueryFieldNames mFieldNames; -}; - -class ResultBind -{ - friend class PreparedResultSet; - public: - - ResultBind(MYSQL_STMT* stmt) : m_rBind(NULL), m_stmt(stmt), m_res(NULL), m_isNull(NULL), m_length(NULL), m_fieldCount(0) {} - - ~ResultBind() + const Field & operator [] (uint32 index) const { - CleanUp(); // Clean up buffer + ASSERT(m_rowPosition < m_rowCount); + ASSERT(index < m_fieldCount); + return m_rows[m_rowPosition][index]; } - void BindResult(uint64& num_rows); - protected: + uint64 m_rowCount; + uint64 m_rowPosition; + std::vector<Field*> m_rows; + uint32 m_fieldCount; + + private: MYSQL_BIND* m_rBind; MYSQL_STMT* m_stmt; MYSQL_RES* m_res; - void FreeBindBuffer(); - bool IsValidIndex(uint32 index) { return index < m_fieldCount; } - - private: - - void CleanUp(); - - size_t SizeForType(MYSQL_FIELD* field) - { - switch (field->type) - { - case MYSQL_TYPE_NULL: - return 0; - case MYSQL_TYPE_TINY: - return 1; - case MYSQL_TYPE_YEAR: - case MYSQL_TYPE_SHORT: - return 2; - case MYSQL_TYPE_INT24: - case MYSQL_TYPE_LONG: - case MYSQL_TYPE_FLOAT: - return 4; - case MYSQL_TYPE_DOUBLE: - case MYSQL_TYPE_LONGLONG: - case MYSQL_TYPE_BIT: - return 8; - - case MYSQL_TYPE_TIMESTAMP: - case MYSQL_TYPE_DATE: - case MYSQL_TYPE_TIME: - case MYSQL_TYPE_DATETIME: - return sizeof(MYSQL_TIME); - - case MYSQL_TYPE_TINY_BLOB: - case MYSQL_TYPE_MEDIUM_BLOB: - case MYSQL_TYPE_LONG_BLOB: - case MYSQL_TYPE_BLOB: - case MYSQL_TYPE_STRING: - case MYSQL_TYPE_VAR_STRING: - return field->max_length + 1; - - case MYSQL_TYPE_DECIMAL: - case MYSQL_TYPE_NEWDECIMAL: - return 64; - - case MYSQL_TYPE_GEOMETRY: - /* - Following types are not sent over the wire: - MYSQL_TYPE_ENUM: - MYSQL_TYPE_SET: - */ - default: - sLog.outSQLDriver("ResultBind::SizeForType(): invalid field type %u", uint32(field->type)); - return 0; - } - } - my_bool* m_isNull; unsigned long* m_length; - uint32 m_fieldCount; -}; -class PreparedResultSet -{ - template<class T> friend class DatabaseWorkerPool; - public: - PreparedResultSet(MYSQL_STMT* stmt) - { - num_rows = 0; - row_position = 0; - rbind = new ResultBind(stmt); - rbind->BindResult(num_rows); - } - ~PreparedResultSet() - { - delete rbind; - } - - operator bool() { return num_rows > 0; } - - bool GetBool(uint32 index); - uint8 GetUInt8(uint32 index); - int8 GetInt8(uint32 index); - uint16 GetUInt16(uint32 index); - int16 GetInt16(uint32 index); - uint32 GetUInt32(uint32 index); - int32 GetInt32(uint32 index); - uint64 GetUInt64(uint32 index); - int64 GetInt64(uint32 index); - float GetFloat(uint32 index); - std::string GetString(uint32 index); - const char* GetCString(uint32 index); - - bool NextRow(); - uint64 GetRowCount() const { return num_rows; } - - private: - bool CheckFieldIndex(uint32 index) const - { - if (!rbind->IsValidIndex(index)) - return false; - - if (rbind->m_isNull[index]) - return false; - - return true; - } + void FreeBindBuffer(); + void CleanUp(); + bool _NextRow(); - ResultBind* rbind; - uint64 row_position; - uint64 num_rows; }; typedef ACE_Refcounted_Auto_Ptr<PreparedResultSet, ACE_Null_Mutex> PreparedQueryResult; diff --git a/src/server/shared/Database/SQLStorageImpl.h b/src/server/shared/Database/SQLStorageImpl.h index 533ce7a37c3..04e905e4aa3 100644 --- a/src/server/shared/Database/SQLStorageImpl.h +++ b/src/server/shared/Database/SQLStorageImpl.h @@ -198,7 +198,7 @@ void SQLStorageLoaderBase<T>::Load(SQLStorage &store) case FT_FLOAT: storeValue((float)fields[x].GetFloat(), store, p, x, offset); break; case FT_STRING: - storeValue((char*)fields[x].GetString(), store, p, x, offset); break; + storeValue((char*)fields[x].GetCString(), store, p, x, offset); break; } ++count; }while( result->NextRow() ); |