diff options
author | Shauren <shauren.trinity@gmail.com> | 2020-03-02 22:50:56 +0100 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2020-03-02 22:51:18 +0100 |
commit | 30482038559d65648d9bcfa29cd426a6f266eeba (patch) | |
tree | e82255fc85fc55b6a06dafbb375bc904b612457f /src | |
parent | 69809d12031d055d8d0cffb5a187d9af09dd0efe (diff) |
Core/DBLayer: Use std::variant's stored type instead of relying on our own separate enum for prepared statement parameters
Diffstat (limited to 'src')
-rw-r--r-- | src/server/database/Database/MySQLPreparedStatement.cpp | 218 | ||||
-rw-r--r-- | src/server/database/Database/MySQLPreparedStatement.h | 22 | ||||
-rw-r--r-- | src/server/database/Database/PreparedStatement.cpp | 61 | ||||
-rw-r--r-- | src/server/database/Database/PreparedStatement.h | 63 |
4 files changed, 115 insertions, 249 deletions
diff --git a/src/server/database/Database/MySQLPreparedStatement.cpp b/src/server/database/Database/MySQLPreparedStatement.cpp index 1eca645326b..7f08b5f1ac6 100644 --- a/src/server/database/Database/MySQLPreparedStatement.cpp +++ b/src/server/database/Database/MySQLPreparedStatement.cpp @@ -20,7 +20,20 @@ #include "Log.h" #include "MySQLHacks.h" #include "PreparedStatement.h" -#include <sstream> + +template<typename T> +struct MySQLType { }; + +template<> struct MySQLType<uint8> : std::integral_constant<enum_field_types, MYSQL_TYPE_TINY> { }; +template<> struct MySQLType<uint16> : std::integral_constant<enum_field_types, MYSQL_TYPE_SHORT> { }; +template<> struct MySQLType<uint32> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONG> { }; +template<> struct MySQLType<uint64> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONGLONG> { }; +template<> struct MySQLType<int8> : std::integral_constant<enum_field_types, MYSQL_TYPE_TINY> { }; +template<> struct MySQLType<int16> : std::integral_constant<enum_field_types, MYSQL_TYPE_SHORT> { }; +template<> struct MySQLType<int32> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONG> { }; +template<> struct MySQLType<int64> : std::integral_constant<enum_field_types, MYSQL_TYPE_LONGLONG> { }; +template<> struct MySQLType<float> : std::integral_constant<enum_field_types, MYSQL_TYPE_FLOAT> { }; +template<> struct MySQLType<double> : std::integral_constant<enum_field_types, MYSQL_TYPE_DOUBLE> { }; MySQLPreparedStatement::MySQLPreparedStatement(MySQLStmt* stmt, std::string queryString) : m_stmt(nullptr), m_Mstmt(stmt), m_bind(nullptr), m_queryString(std::move(queryString)) @@ -55,51 +68,10 @@ void MySQLPreparedStatement::BindParameters(PreparedStatement* stmt) uint8 pos = 0; for (PreparedStatementData const& data : stmt->GetParameters()) { - switch (data.type) + std::visit([&](auto&& param) { - case TYPE_BOOL: - setBool(pos, std::get<bool>(data.data)); - break; - case TYPE_UI8: - setUInt8(pos, std::get<uint8>(data.data)); - break; - case TYPE_UI16: - setUInt16(pos, std::get<uint16>(data.data)); - break; - case TYPE_UI32: - setUInt32(pos, std::get<uint32>(data.data)); - break; - case TYPE_I8: - setInt8(pos, std::get<int8>(data.data)); - break; - case TYPE_I16: - setInt16(pos, std::get<int16>(data.data)); - break; - case TYPE_I32: - setInt32(pos, std::get<int32>(data.data)); - break; - case TYPE_UI64: - setUInt64(pos, std::get<uint64>(data.data)); - break; - case TYPE_I64: - setInt64(pos, std::get<int64>(data.data)); - break; - case TYPE_FLOAT: - setFloat(pos, std::get<float>(data.data)); - break; - case TYPE_DOUBLE: - setDouble(pos, std::get<double>(data.data)); - break; - case TYPE_STRING: - setString(pos, std::get<std::string>(data.data)); - break; - case TYPE_BINARY: - setBinary(pos, std::get<std::vector<uint8>>(data.data)); - break; - case TYPE_NULL: - setNull(pos); - break; - } + SetParameter(pos, param); + }, data.data); ++pos; } #ifdef _DEBUG @@ -126,19 +98,6 @@ static bool ParamenterIndexAssertFail(uint32 stmtIndex, uint8 index, uint32 para return false; } -static void SetParameterValue(MYSQL_BIND* param, enum_field_types type, void const* value, uint32 len, bool isUnsigned) -{ - param->buffer_type = type; - delete[] static_cast<char*>(param->buffer); - param->buffer = new char[len]; - param->buffer_length = 0; - param->is_null_value = 0; - param->length = nullptr; // Only != NULL for strings - param->is_unsigned = isUnsigned; - - memcpy(param->buffer, value, len); -} - //- Bind on mysql level void MySQLPreparedStatement::AssertValidIndex(uint8 index) { @@ -148,7 +107,7 @@ void MySQLPreparedStatement::AssertValidIndex(uint8 index) TC_LOG_ERROR("sql.sql", "[ERROR] Prepared Statement (id: %u) trying to bind value on already bound index (%u).", m_stmt->GetIndex(), index); } -void MySQLPreparedStatement::setNull(const uint8 index) +void MySQLPreparedStatement::SetParameter(uint8 index, std::nullptr_t) { AssertValidIndex(index); m_paramsSet[index] = true; @@ -162,92 +121,30 @@ void MySQLPreparedStatement::setNull(const uint8 index) param->length = nullptr; } -void MySQLPreparedStatement::setBool(const uint8 index, const bool value) -{ - setUInt8(index, value ? 1 : 0); -} - -void MySQLPreparedStatement::setUInt8(const uint8 index, const uint8 value) +void MySQLPreparedStatement::SetParameter(uint8 index, bool value) { - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_TINY, &value, sizeof(uint8), true); -} - -void MySQLPreparedStatement::setUInt16(const uint8 index, const uint16 value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_SHORT, &value, sizeof(uint16), true); -} - -void MySQLPreparedStatement::setUInt32(const uint8 index, const uint32 value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_LONG, &value, sizeof(uint32), true); -} - -void MySQLPreparedStatement::setUInt64(const uint8 index, const uint64 value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_LONGLONG, &value, sizeof(uint64), true); + SetParameter(index, uint8(value ? 1 : 0)); } -void MySQLPreparedStatement::setInt8(const uint8 index, const int8 value) +template<typename T> +void MySQLPreparedStatement::SetParameter(uint8 index, T value) { AssertValidIndex(index); m_paramsSet[index] = true; MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_TINY, &value, sizeof(int8), false); -} - -void MySQLPreparedStatement::setInt16(const uint8 index, const int16 value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_SHORT, &value, sizeof(int16), false); -} - -void MySQLPreparedStatement::setInt32(const uint8 index, const int32 value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_LONG, &value, sizeof(int32), false); -} - -void MySQLPreparedStatement::setInt64(const uint8 index, const int64 value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_LONGLONG, &value, sizeof(int64), false); -} - -void MySQLPreparedStatement::setFloat(const uint8 index, const float value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_FLOAT, &value, sizeof(float), (value > 0.0f)); -} + uint32 len = uint32(sizeof(T)); + param->buffer_type = MySQLType<T>::value; + delete[] static_cast<char*>(param->buffer); + param->buffer = new char[len]; + param->buffer_length = 0; + param->is_null_value = 0; + param->length = nullptr; // Only != NULL for strings + param->is_unsigned = std::is_unsigned_v<T>; -void MySQLPreparedStatement::setDouble(const uint8 index, const double value) -{ - AssertValidIndex(index); - m_paramsSet[index] = true; - MYSQL_BIND* param = &m_bind[index]; - SetParameterValue(param, MYSQL_TYPE_DOUBLE, &value, sizeof(double), (value > 0.0f)); + memcpy(param->buffer, &value, len); } -void MySQLPreparedStatement::setString(const uint8 index, const std::string& value) +void MySQLPreparedStatement::SetParameter(uint8 index, std::string const& value) { AssertValidIndex(index); m_paramsSet[index] = true; @@ -264,7 +161,7 @@ void MySQLPreparedStatement::setString(const uint8 index, const std::string& val memcpy(param->buffer, value.c_str(), len); } -void MySQLPreparedStatement::setBinary(const uint8 index, const std::vector<uint8>& value) +void MySQLPreparedStatement::SetParameter(uint8 index, std::vector<uint8> const& value) { AssertValidIndex(index); m_paramsSet[index] = true; @@ -289,55 +186,12 @@ std::string MySQLPreparedStatement::getQueryString() const for (PreparedStatementData const& data : m_stmt->GetParameters()) { pos = queryString.find('?', pos); - std::stringstream ss; - switch (data.type) + std::string replaceStr = std::visit([&](auto&& data) { - case TYPE_BOOL: - ss << uint16(std::get<bool>(data.data)); // stringstream will append a character with that code instead of numeric representation - break; - case TYPE_UI8: - ss << uint16(std::get<uint8>(data.data)); // stringstream will append a character with that code instead of numeric representation - break; - case TYPE_UI16: - ss << std::get<uint16>(data.data); - break; - case TYPE_UI32: - ss << std::get<uint32>(data.data); - break; - case TYPE_I8: - ss << int16(std::get<int8>(data.data)); // stringstream will append a character with that code instead of numeric representation - break; - case TYPE_I16: - ss << std::get<int16>(data.data); - break; - case TYPE_I32: - ss << std::get<int32>(data.data); - break; - case TYPE_UI64: - ss << std::get<uint64>(data.data); - break; - case TYPE_I64: - ss << std::get<int64>(data.data); - break; - case TYPE_FLOAT: - ss << std::get<float>(data.data); - break; - case TYPE_DOUBLE: - ss << std::get<double>(data.data); - break; - case TYPE_STRING: - ss << '\'' << std::get<std::string>(data.data) << '\''; - break; - case TYPE_BINARY: - ss << "BINARY"; - break; - case TYPE_NULL: - ss << "NULL"; - break; - } + return PreparedStatementData::ToString(data); + }, data.data); - std::string replaceStr = ss.str(); queryString.replace(pos, 1, replaceStr); pos += replaceStr.length(); } diff --git a/src/server/database/Database/MySQLPreparedStatement.h b/src/server/database/Database/MySQLPreparedStatement.h index cd8a234e111..27ce911a0ba 100644 --- a/src/server/database/Database/MySQLPreparedStatement.h +++ b/src/server/database/Database/MySQLPreparedStatement.h @@ -41,24 +41,16 @@ class TC_DATABASE_API MySQLPreparedStatement void BindParameters(PreparedStatement* stmt); - void setNull(const uint8 index); - void setBool(const uint8 index, const bool value); - void setUInt8(const uint8 index, const uint8 value); - void setUInt16(const uint8 index, const uint16 value); - void setUInt32(const uint8 index, const uint32 value); - void setUInt64(const uint8 index, const uint64 value); - void setInt8(const uint8 index, const int8 value); - void setInt16(const uint8 index, const int16 value); - void setInt32(const uint8 index, const int32 value); - void setInt64(const uint8 index, const int64 value); - void setFloat(const uint8 index, const float value); - void setDouble(const uint8 index, const double value); - void setString(const uint8 index, const std::string& value); - void setBinary(const uint8 index, const std::vector<uint8>& value); - uint32 GetParameterCount() const { return m_paramCount; } protected: + void SetParameter(uint8 index, std::nullptr_t); + void SetParameter(uint8 index, bool value); + template<typename T> + void SetParameter(uint8 index, T value); + void SetParameter(uint8 index, std::string const& value); + void SetParameter(uint8 index, std::vector<uint8> const& value); + MySQLStmt* GetSTMT() { return m_Mstmt; } MySQLBind* GetBind() { return m_bind; } PreparedStatement* m_stmt; diff --git a/src/server/database/Database/PreparedStatement.cpp b/src/server/database/Database/PreparedStatement.cpp index 92667e771e9..cfb4dcc3dcf 100644 --- a/src/server/database/Database/PreparedStatement.cpp +++ b/src/server/database/Database/PreparedStatement.cpp @@ -33,97 +33,84 @@ void PreparedStatement::setBool(const uint8 index, const bool value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_BOOL; } void PreparedStatement::setUInt8(const uint8 index, const uint8 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_UI8; } void PreparedStatement::setUInt16(const uint8 index, const uint16 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_UI16; } void PreparedStatement::setUInt32(const uint8 index, const uint32 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_UI32; } void PreparedStatement::setUInt64(const uint8 index, const uint64 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_UI64; } void PreparedStatement::setInt8(const uint8 index, const int8 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_I8; } void PreparedStatement::setInt16(const uint8 index, const int16 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_I16; } void PreparedStatement::setInt32(const uint8 index, const int32 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_I32; } void PreparedStatement::setInt64(const uint8 index, const int64 value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_I64; } void PreparedStatement::setFloat(const uint8 index, const float value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_FLOAT; } void PreparedStatement::setDouble(const uint8 index, const double value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_DOUBLE; } void PreparedStatement::setString(const uint8 index, const std::string& value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_STRING; } void PreparedStatement::setBinary(const uint8 index, const std::vector<uint8>& value) { ASSERT(index < statement_data.size()); statement_data[index].data = value; - statement_data[index].type = TYPE_BINARY; } void PreparedStatement::setNull(const uint8 index) { ASSERT(index < statement_data.size()); - statement_data[index].type = TYPE_NULL; + statement_data[index].data = nullptr; } //- Execution @@ -159,3 +146,49 @@ bool PreparedStatementTask::Execute() return m_conn->Execute(m_stmt); } + +template<typename T> +std::string PreparedStatementData::ToString(T value) +{ + return fmt::format("{}", value); +} + +std::string PreparedStatementData::ToString(bool value) +{ + return ToString<uint32>(value); +} + +std::string PreparedStatementData::ToString(uint8 value) +{ + return ToString<uint32>(value); +} + +template std::string PreparedStatementData::ToString<uint16>(uint16); +template std::string PreparedStatementData::ToString<uint32>(uint32); +template std::string PreparedStatementData::ToString<uint64>(uint64); + +std::string PreparedStatementData::ToString(int8 value) +{ + return ToString<int32>(value); +} + +template std::string PreparedStatementData::ToString<int16>(int16); +template std::string PreparedStatementData::ToString<int32>(int32); +template std::string PreparedStatementData::ToString<int64>(int64); +template std::string PreparedStatementData::ToString<float>(float); +template std::string PreparedStatementData::ToString<double>(double); + +std::string PreparedStatementData::ToString(std::string const& value) +{ + return fmt::format("'{}'", value); +} + +std::string PreparedStatementData::ToString(std::vector<uint8> const& /*value*/) +{ + return "BINARY"; +} + +std::string PreparedStatementData::ToString(std::nullptr_t) +{ + return "NULL"; +} diff --git a/src/server/database/Database/PreparedStatement.h b/src/server/database/Database/PreparedStatement.h index b7730c6c0c1..d87bd3ba2aa 100644 --- a/src/server/database/Database/PreparedStatement.h +++ b/src/server/database/Database/PreparedStatement.h @@ -24,47 +24,34 @@ #include <vector> #include <variant> -#ifdef __APPLE__ -#undef TYPE_BOOL -#endif - -//- This enum helps us differ data held in above union -enum PreparedStatementValueType -{ - TYPE_BOOL, - TYPE_UI8, - TYPE_UI16, - TYPE_UI32, - TYPE_UI64, - TYPE_I8, - TYPE_I16, - TYPE_I32, - TYPE_I64, - TYPE_FLOAT, - TYPE_DOUBLE, - TYPE_STRING, - TYPE_BINARY, - TYPE_NULL -}; - struct PreparedStatementData { std::variant< - bool, // TYPE_BOOL - uint8, // TYPE_UI8 - uint16, // TYPE_UI16 - uint32, // TYPE_UI32 - uint64, // TYPE_UI64 - int8, // TYPE_I8 - int16, // TYPE_I16 - int32, // TYPE_UI32 - int64, // TYPE_UI64 - float, // TYPE_FLOAT - double, // TYPE_DOUBLE - std::string, // TYPE_STRING - std::vector<uint8>> // TYPE_BINARY - data; - PreparedStatementValueType type; + bool, + uint8, + uint16, + uint32, + uint64, + int8, + int16, + int32, + int64, + float, + double, + std::string, + std::vector<uint8>, + std::nullptr_t + > data; + + template<typename T> + static std::string ToString(T value); + + static std::string ToString(bool value); + static std::string ToString(uint8 value); + static std::string ToString(int8 value); + static std::string ToString(std::string const& value); + static std::string ToString(std::vector<uint8> const& value); + static std::string ToString(std::nullptr_t); }; //- Upper-level class that is used in code |