Core/DBC: Sanitize DBC loading procedure. Extra checks. Capability to load strings from DB. Load SpellName from spell_dbc.

This commit is contained in:
Treeston
2018-10-02 19:00:24 +02:00
parent 8bd8d905c5
commit 8edea4a3c2
7 changed files with 142 additions and 119 deletions

View File

@@ -0,0 +1,2 @@
--
ALTER TABLE spell_dbc CHANGE COLUMN Comment SpellName VARCHAR(100) AFTER EffectSpellClassMaskC3;

View File

@@ -221,7 +221,8 @@ static bool LoadDBC_assert_print(uint32 fsize, uint32 rsize, const std::string&
} }
template<class T> template<class T>
inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCStorage<T>& storage, std::string const& dbcPath, std::string const& filename, std::string const& customFormat = std::string(), std::string const& customIndexName = std::string()) inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCStorage<T>& storage, std::string const& dbcPath, std::string const& filename,
char const* dbTable = nullptr, char const* dbFormat = nullptr, char const* dbIndexName = nullptr)
{ {
// compatibility format and C++ structure sizes // compatibility format and C++ structure sizes
ASSERT(DBCFileLoader::GetFormatRecordSize(storage.GetFormat()) == sizeof(T) || LoadDBC_assert_print(DBCFileLoader::GetFormatRecordSize(storage.GetFormat()), sizeof(T), filename)); ASSERT(DBCFileLoader::GetFormatRecordSize(storage.GetFormat()) == sizeof(T) || LoadDBC_assert_print(DBCFileLoader::GetFormatRecordSize(storage.GetFormat()), sizeof(T), filename));
@@ -229,7 +230,7 @@ inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCSt
++DBCFileCount; ++DBCFileCount;
std::string dbcFilename = dbcPath + filename; std::string dbcFilename = dbcPath + filename;
if (storage.Load(dbcFilename)) if (storage.Load(dbcFilename.c_str()))
{ {
for (uint8 i = 0; i < TOTAL_LOCALES; ++i) for (uint8 i = 0; i < TOTAL_LOCALES; ++i)
{ {
@@ -241,12 +242,12 @@ inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCSt
localizedName.push_back('/'); localizedName.push_back('/');
localizedName.append(filename); localizedName.append(filename);
if (!storage.LoadStringsFrom(localizedName)) if (!storage.LoadStringsFrom(localizedName.c_str()))
availableDbcLocales &= ~(1 << i); // mark as not available for speedup next checks availableDbcLocales &= ~(1 << i); // mark as not available for speedup next checks
} }
if (!customFormat.empty()) if (dbTable)
storage.LoadFromDB(filename, customFormat, customIndexName); storage.LoadFromDB(dbTable, dbFormat, dbIndexName);
} }
else else
{ {
@@ -390,11 +391,11 @@ void LoadDBCStores(const std::string& dataPath)
#undef LOAD_DBC #undef LOAD_DBC
#define LOAD_DBC_EXT(store, file, dbformat, dbpk) LoadDBC(availableDbcLocales, bad_dbc_files, store, dbcPath, file, dbformat, dbpk) #define LOAD_DBC_EXT(store, file, dbtable, dbformat, dbpk) LoadDBC(availableDbcLocales, bad_dbc_files, store, dbcPath, file, dbtable, dbformat, dbpk)
LOAD_DBC_EXT(sAchievementStore, "Achievement.dbc", CustomAchievementfmt, CustomAchievementIndex); LOAD_DBC_EXT(sAchievementStore, "Achievement.dbc", "achievement_dbc", CustomAchievementfmt, CustomAchievementIndex);
LOAD_DBC_EXT(sSpellStore, "Spell.dbc", CustomSpellEntryfmt, CustomSpellEntryIndex); LOAD_DBC_EXT(sSpellStore, "Spell.dbc", "spell_dbc", CustomSpellEntryfmt, CustomSpellEntryIndex);
LOAD_DBC_EXT(sSpellDifficultyStore, "SpellDifficulty.dbc", CustomSpellDifficultyfmt, CustomSpellDifficultyIndex); LOAD_DBC_EXT(sSpellDifficultyStore, "SpellDifficulty.dbc", "spelldifficulty_dbc", CustomSpellDifficultyfmt, CustomSpellDifficultyIndex);
#undef LOAD_DBC_EXT #undef LOAD_DBC_EXT

View File

@@ -20,53 +20,75 @@
#include "DatabaseEnv.h" #include "DatabaseEnv.h"
#include "Errors.h" #include "Errors.h"
#include "Log.h" #include "Log.h"
#include "StringFormat.h"
#include <sstream> #include <sstream>
DBCDatabaseLoader::DBCDatabaseLoader(std::string const& storageName, std::string const& dbFormatString, std::string const& primaryKey, char const* dbcFormatString) DBCDatabaseLoader::DBCDatabaseLoader(char const* tableName, char const* dbFormatString, char const* primaryKey, char const* dbcFormatString)
: _formatString(dbFormatString), _indexName(primaryKey), _sqlTableName(storageName), _dbcFormat(dbcFormatString), _sqlIndexPos(0), _recordSize(0) : _sqlTableName(tableName), _formatString(dbFormatString), _indexName(primaryKey), _dbcFormat(dbcFormatString), _sqlIndexPos(0), _recordSize(0)
{ {
// Convert dbc file name to sql table name
std::transform(_sqlTableName.begin(), _sqlTableName.end(), _sqlTableName.begin(), ::tolower);
for (char& c : _sqlTableName)
if (c == '.')
c = '_';
// Get sql index position // Get sql index position
int32 indexPos = -1; int32 indexPos = -1;
_recordSize = DBCFileLoader::GetFormatRecordSize(_dbcFormat, &indexPos); _recordSize = DBCFileLoader::GetFormatRecordSize(_dbcFormat, &indexPos);
ASSERT(indexPos >= 0); ASSERT(indexPos >= 0);
ASSERT(_recordSize); ASSERT(_recordSize);
uint32 uindexPos = uint32(indexPos); uint32 uIndexPos = uint32(indexPos);
for (uint32 x = 0; x < _formatString.size(); ++x) char const* fmt = _formatString;
while (uIndexPos)
{ {
// Count only fields present in sql switch (*fmt)
if (_formatString[x] == FT_SQL_PRESENT)
{ {
if (x == uindexPos) case FT_SQL_PRESENT:
++_sqlIndexPos;
case FT_SQL_ABSENT:
break;
default:
ASSERT(false, "Invalid DB format string for '%s'", tableName);
break; break;
++_sqlIndexPos;
} }
--uIndexPos;
++fmt;
} }
ASSERT(*fmt == FT_SQL_PRESENT, "Index column not present in format string for '%s'", tableName);
} }
static char const* nullStr = ""; static char const* nullStr = "";
struct CleanupStruct
{
static char* Clone(std::string const& str)
{
char* ptr = new char[str.size() + 1];
memcpy(ptr, str.c_str(), str.size() + 1);
_instance()._managed.push_back(ptr);
return ptr;
}
~CleanupStruct()
{
for (char* ptr : _managed)
delete[] ptr;
_managed.clear();
}
private:
static CleanupStruct& _instance() { static CleanupStruct c; return c; }
std::vector<char*> _managed;
};
char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable) char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable)
{ {
std::ostringstream queryBuilder; std::string query = Trinity::StringFormat("SELECT * FROM %s ORDER BY %s DESC;", _sqlTableName, _indexName);
queryBuilder << "SELECT * FROM " << _sqlTableName
<< " ORDER BY " << _indexName << " DESC;";
// no error if empty set // no error if empty set
QueryResult result = WorldDatabase.Query(queryBuilder.str().c_str()); QueryResult result = WorldDatabase.Query(query.c_str());
if (!result) if (!result)
return nullptr; return nullptr;
// Check if sql index pos is valid // Check if sql index pos is valid
if (int32(result->GetFieldCount() - 1) < _sqlIndexPos) if (int32(result->GetFieldCount() - 1) < _sqlIndexPos)
{ {
ASSERT(false, "Invalid index pos for dbc:'%s'", _sqlTableName.c_str()); ASSERT(false, "Invalid index pos for dbc: '%s'", _sqlTableName);
return nullptr; return nullptr;
} }
@@ -90,7 +112,6 @@ char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable)
do do
{ {
Field* fields = result->Fetch(); Field* fields = result->Fetch();
uint32 offset = 0;
uint32 indexValue = fields[_sqlIndexPos].GetUInt32(); uint32 indexValue = fields[_sqlIndexPos].GetUInt32();
@@ -103,82 +124,81 @@ char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable)
else else
{ {
// Attempt to overwrite existing data // Attempt to overwrite existing data
ASSERT(false, "Index %d already exists in dbc:'%s'", indexValue, _sqlTableName.c_str()); ASSERT(false, "Index %d already exists in dbc:'%s'", indexValue, _sqlTableName);
return nullptr; return nullptr;
} }
uint32 columnNumber = 0; uint32 dataOffset = 0;
uint32 sqlColumnNumber = 0; uint32 sqlColumnNumber = 0;
char const* dbcFormat = _dbcFormat;
for (; columnNumber < _formatString.size(); ++columnNumber) char const* sqlFormat = _formatString;
for (; (*dbcFormat || *sqlFormat); ++dbcFormat, ++sqlFormat)
{ {
if (_formatString[columnNumber] == FT_SQL_ABSENT) if (!*dbcFormat || !*sqlFormat)
{ {
switch (_dbcFormat[columnNumber]) ASSERT(false, "DB and DBC format strings do not have the same length for '%s'", _sqlTableName);
{ return nullptr;
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[offset]) = 0.0f;
offset += 4;
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[offset]) = uint32(0);
offset += 4;
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[offset]) = uint8(0);
offset += 1;
break;
case FT_STRING:
*reinterpret_cast<char**>(&dataValue[offset]) = const_cast<char*>(nullStr);
offset += sizeof(char*);
break;
}
} }
else if (_formatString[columnNumber] == FT_SQL_PRESENT) if (!*dbcFormat)
{
bool validSqlColumn = true;
switch (_dbcFormat[columnNumber])
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[offset]) = fields[sqlColumnNumber].GetFloat();
offset += 4;
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[offset]) = fields[sqlColumnNumber].GetUInt32();
offset += 4;
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[offset]) = fields[sqlColumnNumber].GetUInt8();
offset += 1;
break;
case FT_STRING:
ASSERT(false, "Unsupported data type in table '%s' at char %d", _sqlTableName.c_str(), columnNumber);
return nullptr;
case FT_SORT:
break;
default:
validSqlColumn = false;
break;
}
if (validSqlColumn && (columnNumber != (_formatString.size() - 1)))
sqlColumnNumber++;
}
else
{
ASSERT(false, "Incorrect sql format string '%s' at char %d", _sqlTableName.c_str(), columnNumber);
break; break;
switch (*sqlFormat)
{
case FT_SQL_PRESENT:
switch (*dbcFormat)
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[dataOffset]) = fields[sqlColumnNumber].GetFloat();
dataOffset += sizeof(float);
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[dataOffset]) = fields[sqlColumnNumber].GetUInt32();
dataOffset += sizeof(uint32);
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[dataOffset]) = fields[sqlColumnNumber].GetUInt8();
dataOffset += sizeof(uint8);
break;
case FT_STRING:
*reinterpret_cast<char**>(&dataValue[dataOffset]) = CleanupStruct::Clone(fields[sqlColumnNumber].GetString());
dataOffset += sizeof(char*);
case FT_SORT:
break;
default:
ASSERT(false, "Unsupported data type '%c' marked present in table '%s'", *dbcFormat, _sqlTableName);
return nullptr;
}
++sqlColumnNumber;
break;
case FT_SQL_ABSENT:
switch (*dbcFormat)
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[dataOffset]) = 0.0f;
dataOffset += 4;
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[dataOffset]) = uint32(0);
dataOffset += 4;
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[dataOffset]) = uint8(0);
dataOffset += 1;
break;
case FT_STRING:
*reinterpret_cast<char**>(&dataValue[dataOffset]) = const_cast<char*>(nullStr);
dataOffset += sizeof(char*);
break;
}
break;
default:
ASSERT(false, "Invalid DB format string for '%s'", _sqlTableName);
return nullptr;
} }
} }
ASSERT(sqlColumnNumber == result->GetFieldCount(), "SQL format string does not match database for table: '%s'", _sqlTableName);
if (sqlColumnNumber != (result->GetFieldCount() - 1)) ASSERT(dataOffset == _recordSize);
{
ASSERT(false, "SQL and DBC format strings are not matching for table: '%s'", _sqlTableName.c_str());
return nullptr;
}
ASSERT(offset == _recordSize);
} while (result->NextRow()); } while (result->NextRow());
ASSERT(newRecords == result->GetRowCount()); ASSERT(newRecords == result->GetRowCount());

View File

@@ -24,14 +24,14 @@
struct TC_SHARED_API DBCDatabaseLoader struct TC_SHARED_API DBCDatabaseLoader
{ {
DBCDatabaseLoader(std::string const& storageName, std::string const& dbFormatString, std::string const& primaryKey, char const* dbcFormatString); DBCDatabaseLoader(char const* dbTable, char const* dbFormatString, char const* index, char const* dbcFormatString);
char* Load(uint32& records, char**& indexTable); char* Load(uint32& records, char**& indexTable);
private: private:
std::string const& _formatString; char const* _sqlTableName;
std::string const& _indexName; char const* _formatString;
std::string _sqlTableName; char const* _indexName;
char const* _dbcFormat; char const* _dbcFormat;
int32 _sqlIndexPos; int32 _sqlIndexPos;
uint32 _recordSize; uint32 _recordSize;

View File

@@ -30,13 +30,13 @@ DBCStorageBase::~DBCStorageBase()
delete[] strings; delete[] strings;
} }
bool DBCStorageBase::Load(std::string const& path, char**& indexTable) bool DBCStorageBase::Load(char const* path, char**& indexTable)
{ {
indexTable = nullptr; indexTable = nullptr;
DBCFileLoader dbc; DBCFileLoader dbc;
// Check if load was sucessful, only then continue // Check if load was sucessful, only then continue
if (!dbc.Load(path.c_str(), _fileFormat)) if (!dbc.Load(path, _fileFormat))
return false; return false;
_fieldCount = dbc.GetCols(); _fieldCount = dbc.GetCols();
@@ -52,7 +52,7 @@ bool DBCStorageBase::Load(std::string const& path, char**& indexTable)
return indexTable != nullptr; return indexTable != nullptr;
} }
bool DBCStorageBase::LoadStringsFrom(std::string const& path, char** indexTable) bool DBCStorageBase::LoadStringsFrom(char const* path, char** indexTable)
{ {
// DBC must be already loaded using Load // DBC must be already loaded using Load
if (!indexTable) if (!indexTable)
@@ -60,7 +60,7 @@ bool DBCStorageBase::LoadStringsFrom(std::string const& path, char** indexTable)
DBCFileLoader dbc; DBCFileLoader dbc;
// Check if load was successful, only then continue // Check if load was successful, only then continue
if (!dbc.Load(path.c_str(), _fileFormat)) if (!dbc.Load(path, _fileFormat))
return false; return false;
// load strings from another locale dbc data // load strings from another locale dbc data
@@ -70,7 +70,7 @@ bool DBCStorageBase::LoadStringsFrom(std::string const& path, char** indexTable)
return true; return true;
} }
void DBCStorageBase::LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey, char**& indexTable) void DBCStorageBase::LoadFromDB(char const* table, char const* format, char const* index, char**& indexTable)
{ {
_dataTableEx = DBCDatabaseLoader(path, dbFormat, primaryKey, _fileFormat).Load(_indexTableSize, indexTable); _dataTableEx = DBCDatabaseLoader(table, format, index, _fileFormat).Load(_indexTableSize, indexTable);
} }

View File

@@ -33,14 +33,14 @@ class TC_SHARED_API DBCStorageBase
char const* GetFormat() const { return _fileFormat; } char const* GetFormat() const { return _fileFormat; }
uint32 GetFieldCount() const { return _fieldCount; } uint32 GetFieldCount() const { return _fieldCount; }
virtual bool Load(std::string const& path) = 0; virtual bool Load(char const* path) = 0;
virtual bool LoadStringsFrom(std::string const& path) = 0; virtual bool LoadStringsFrom(char const* path) = 0;
virtual void LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey) = 0; virtual void LoadFromDB(char const* table, char const* format, char const* index) = 0;
protected: protected:
bool Load(std::string const& path, char**& indexTable); bool Load(char const* path, char**& indexTable);
bool LoadStringsFrom(std::string const& path, char** indexTable); bool LoadStringsFrom(char const* path, char** indexTable);
void LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey, char**& indexTable); void LoadFromDB(char const* table, char const* format, char const* index, char**& indexTable);
uint32 _fieldCount; uint32 _fieldCount;
char const* _fileFormat; char const* _fileFormat;
@@ -71,19 +71,19 @@ class DBCStorage : public DBCStorageBase
uint32 GetNumRows() const { return _indexTableSize; } uint32 GetNumRows() const { return _indexTableSize; }
bool Load(std::string const& path) override bool Load(char const* path) override
{ {
return DBCStorageBase::Load(path, _indexTable.AsChar); return DBCStorageBase::Load(path, _indexTable.AsChar);
} }
bool LoadStringsFrom(std::string const& path) override bool LoadStringsFrom(char const* path) override
{ {
return DBCStorageBase::LoadStringsFrom(path, _indexTable.AsChar); return DBCStorageBase::LoadStringsFrom(path, _indexTable.AsChar);
} }
void LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey) override void LoadFromDB(char const* table, char const* format, char const* index) override
{ {
DBCStorageBase::LoadFromDB(path, dbFormat, primaryKey, _indexTable.AsChar); DBCStorageBase::LoadFromDB(table, format, index, _indexTable.AsChar);
} }
iterator begin() { return iterator(_indexTable.AsT, _indexTableSize); } iterator begin() { return iterator(_indexTable.AsT, _indexTableSize); }

View File

@@ -19,7 +19,7 @@
#ifndef TRINITY_DBCSFRM_H #ifndef TRINITY_DBCSFRM_H
#define TRINITY_DBCSFRM_H #define TRINITY_DBCSFRM_H
char constexpr Achievementfmt[] = "niixssssssssssssssssxxxxxxxxxxxxxxxxxxiixixxxxxxxxxxxxxxxxxxii"; char constexpr Achievementfmt[] = "niixssssssssssssssssxxxxxxxxxxxxxxxxxxiixixxxxxxxxxxxxxxxxxxii";
char constexpr CustomAchievementfmt[] = "pppaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaapapaaaaaaaaaaaaaaaaaapp"; char constexpr CustomAchievementfmt[] = "pppaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaapapaaaaaaaaaaaaaaaaaapp";
char constexpr CustomAchievementIndex[] = "ID"; char constexpr CustomAchievementIndex[] = "ID";
char constexpr AchievementCriteriafmt[] = "niiiiiiiixxxxxxxxxxxxxxxxxiiiix"; char constexpr AchievementCriteriafmt[] = "niiiiiiiixxxxxxxxxxxxxxxxxiiiix";
@@ -109,12 +109,12 @@ char constexpr SkillTiersfmt[] = "nxxxxxxxxxxxxxxxxiiiiiiiiiiiiiiii";
char constexpr SoundEntriesfmt[] = "nxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; char constexpr SoundEntriesfmt[] = "nxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
char constexpr SpellCastTimefmt[] = "nixx"; char constexpr SpellCastTimefmt[] = "nixx";
char constexpr SpellCategoryfmt[] = "ni"; char constexpr SpellCategoryfmt[] = "ni";
char constexpr SpellDifficultyfmt[] = "niiii"; char constexpr SpellDifficultyfmt[] = "niiii";
char constexpr CustomSpellDifficultyfmt[] = "ppppp"; char constexpr CustomSpellDifficultyfmt[] = "ppppp";
char constexpr CustomSpellDifficultyIndex[] = "id"; char constexpr CustomSpellDifficultyIndex[] = "id";
char constexpr SpellDurationfmt[] = "niii"; char constexpr SpellDurationfmt[] = "niii";
char constexpr SpellEntryfmt[] = "niiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiifxiiiiiiiiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiifffiiiiiiiiiiiiiissssssssssssssssxssssssssssssssssxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxiiiiiiiiiiixfffxxxiiiiixxfffxx"; char constexpr SpellEntryfmt[] = "niiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiifxiiiiiiiiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiifffiiiiiiiiiiiiiissssssssssssssssxssssssssssssssssxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxiiiiiiiiiiixfffxxxiiiiixxfffxx";
char constexpr CustomSpellEntryfmt[] = "papppppppppppapapaaaaaaaaaaapaaapapppppppaaaaapaapaaaaaaaaaaaaaaaaaappppppppppppppppppppppppppppppppppppaaappppppppppppaaapppppppppaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaappppppppapppaaaaappaaaaaaa"; char constexpr CustomSpellEntryfmt[] = "papppppppppppapapaaaaaaaaaaapaaapapppppppaaaaapaapaaaaaaaaaaaaaaaaaappppppppppppppppppppppppppppppppppppaaappppppppppppaaapppppppppaaaaapaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaappppppppapppaaaaappaaaaaaaa";
char constexpr CustomSpellEntryIndex[] = "Id"; char constexpr CustomSpellEntryIndex[] = "Id";
char constexpr SpellFocusObjectfmt[] = "nxxxxxxxxxxxxxxxxx"; char constexpr SpellFocusObjectfmt[] = "nxxxxxxxxxxxxxxxxx";
char constexpr SpellItemEnchantmentfmt[] = "nxiiiiiixxxiiissssssssssssssssxiiiiiii"; char constexpr SpellItemEnchantmentfmt[] = "nxiiiiiixxxiiissssssssssssssssxiiiiiii";