Core/DBC: Sanitize DBC loading procedure. Extra checks. Capability to load strings from DB. Load SpellName from spell_dbc.

This commit is contained in:
Treeston
2018-10-02 19:00:24 +02:00
parent 8bd8d905c5
commit 8edea4a3c2
7 changed files with 142 additions and 119 deletions

View File

@@ -0,0 +1,2 @@
--
ALTER TABLE spell_dbc CHANGE COLUMN Comment SpellName VARCHAR(100) AFTER EffectSpellClassMaskC3;

View File

@@ -221,7 +221,8 @@ static bool LoadDBC_assert_print(uint32 fsize, uint32 rsize, const std::string&
}
template<class T>
inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCStorage<T>& storage, std::string const& dbcPath, std::string const& filename, std::string const& customFormat = std::string(), std::string const& customIndexName = std::string())
inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCStorage<T>& storage, std::string const& dbcPath, std::string const& filename,
char const* dbTable = nullptr, char const* dbFormat = nullptr, char const* dbIndexName = nullptr)
{
// compatibility format and C++ structure sizes
ASSERT(DBCFileLoader::GetFormatRecordSize(storage.GetFormat()) == sizeof(T) || LoadDBC_assert_print(DBCFileLoader::GetFormatRecordSize(storage.GetFormat()), sizeof(T), filename));
@@ -229,7 +230,7 @@ inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCSt
++DBCFileCount;
std::string dbcFilename = dbcPath + filename;
if (storage.Load(dbcFilename))
if (storage.Load(dbcFilename.c_str()))
{
for (uint8 i = 0; i < TOTAL_LOCALES; ++i)
{
@@ -241,12 +242,12 @@ inline void LoadDBC(uint32& availableDbcLocales, StoreProblemList& errors, DBCSt
localizedName.push_back('/');
localizedName.append(filename);
if (!storage.LoadStringsFrom(localizedName))
if (!storage.LoadStringsFrom(localizedName.c_str()))
availableDbcLocales &= ~(1 << i); // mark as not available for speedup next checks
}
if (!customFormat.empty())
storage.LoadFromDB(filename, customFormat, customIndexName);
if (dbTable)
storage.LoadFromDB(dbTable, dbFormat, dbIndexName);
}
else
{
@@ -390,11 +391,11 @@ void LoadDBCStores(const std::string& dataPath)
#undef LOAD_DBC
#define LOAD_DBC_EXT(store, file, dbformat, dbpk) LoadDBC(availableDbcLocales, bad_dbc_files, store, dbcPath, file, dbformat, dbpk)
#define LOAD_DBC_EXT(store, file, dbtable, dbformat, dbpk) LoadDBC(availableDbcLocales, bad_dbc_files, store, dbcPath, file, dbtable, dbformat, dbpk)
LOAD_DBC_EXT(sAchievementStore, "Achievement.dbc", CustomAchievementfmt, CustomAchievementIndex);
LOAD_DBC_EXT(sSpellStore, "Spell.dbc", CustomSpellEntryfmt, CustomSpellEntryIndex);
LOAD_DBC_EXT(sSpellDifficultyStore, "SpellDifficulty.dbc", CustomSpellDifficultyfmt, CustomSpellDifficultyIndex);
LOAD_DBC_EXT(sAchievementStore, "Achievement.dbc", "achievement_dbc", CustomAchievementfmt, CustomAchievementIndex);
LOAD_DBC_EXT(sSpellStore, "Spell.dbc", "spell_dbc", CustomSpellEntryfmt, CustomSpellEntryIndex);
LOAD_DBC_EXT(sSpellDifficultyStore, "SpellDifficulty.dbc", "spelldifficulty_dbc", CustomSpellDifficultyfmt, CustomSpellDifficultyIndex);
#undef LOAD_DBC_EXT

View File

@@ -20,53 +20,75 @@
#include "DatabaseEnv.h"
#include "Errors.h"
#include "Log.h"
#include "StringFormat.h"
#include <sstream>
DBCDatabaseLoader::DBCDatabaseLoader(std::string const& storageName, std::string const& dbFormatString, std::string const& primaryKey, char const* dbcFormatString)
: _formatString(dbFormatString), _indexName(primaryKey), _sqlTableName(storageName), _dbcFormat(dbcFormatString), _sqlIndexPos(0), _recordSize(0)
DBCDatabaseLoader::DBCDatabaseLoader(char const* tableName, char const* dbFormatString, char const* primaryKey, char const* dbcFormatString)
: _sqlTableName(tableName), _formatString(dbFormatString), _indexName(primaryKey), _dbcFormat(dbcFormatString), _sqlIndexPos(0), _recordSize(0)
{
// Convert dbc file name to sql table name
std::transform(_sqlTableName.begin(), _sqlTableName.end(), _sqlTableName.begin(), ::tolower);
for (char& c : _sqlTableName)
if (c == '.')
c = '_';
// Get sql index position
int32 indexPos = -1;
_recordSize = DBCFileLoader::GetFormatRecordSize(_dbcFormat, &indexPos);
ASSERT(indexPos >= 0);
ASSERT(_recordSize);
uint32 uindexPos = uint32(indexPos);
for (uint32 x = 0; x < _formatString.size(); ++x)
uint32 uIndexPos = uint32(indexPos);
char const* fmt = _formatString;
while (uIndexPos)
{
// Count only fields present in sql
if (_formatString[x] == FT_SQL_PRESENT)
switch (*fmt)
{
if (x == uindexPos)
case FT_SQL_PRESENT:
++_sqlIndexPos;
case FT_SQL_ABSENT:
break;
default:
ASSERT(false, "Invalid DB format string for '%s'", tableName);
break;
++_sqlIndexPos;
}
--uIndexPos;
++fmt;
}
ASSERT(*fmt == FT_SQL_PRESENT, "Index column not present in format string for '%s'", tableName);
}
static char const* nullStr = "";
struct CleanupStruct
{
static char* Clone(std::string const& str)
{
char* ptr = new char[str.size() + 1];
memcpy(ptr, str.c_str(), str.size() + 1);
_instance()._managed.push_back(ptr);
return ptr;
}
~CleanupStruct()
{
for (char* ptr : _managed)
delete[] ptr;
_managed.clear();
}
private:
static CleanupStruct& _instance() { static CleanupStruct c; return c; }
std::vector<char*> _managed;
};
char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable)
{
std::ostringstream queryBuilder;
queryBuilder << "SELECT * FROM " << _sqlTableName
<< " ORDER BY " << _indexName << " DESC;";
std::string query = Trinity::StringFormat("SELECT * FROM %s ORDER BY %s DESC;", _sqlTableName, _indexName);
// no error if empty set
QueryResult result = WorldDatabase.Query(queryBuilder.str().c_str());
QueryResult result = WorldDatabase.Query(query.c_str());
if (!result)
return nullptr;
// Check if sql index pos is valid
if (int32(result->GetFieldCount() - 1) < _sqlIndexPos)
{
ASSERT(false, "Invalid index pos for dbc:'%s'", _sqlTableName.c_str());
ASSERT(false, "Invalid index pos for dbc: '%s'", _sqlTableName);
return nullptr;
}
@@ -90,7 +112,6 @@ char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable)
do
{
Field* fields = result->Fetch();
uint32 offset = 0;
uint32 indexValue = fields[_sqlIndexPos].GetUInt32();
@@ -103,82 +124,81 @@ char* DBCDatabaseLoader::Load(uint32& records, char**& indexTable)
else
{
// Attempt to overwrite existing data
ASSERT(false, "Index %d already exists in dbc:'%s'", indexValue, _sqlTableName.c_str());
ASSERT(false, "Index %d already exists in dbc:'%s'", indexValue, _sqlTableName);
return nullptr;
}
uint32 columnNumber = 0;
uint32 dataOffset = 0;
uint32 sqlColumnNumber = 0;
for (; columnNumber < _formatString.size(); ++columnNumber)
char const* dbcFormat = _dbcFormat;
char const* sqlFormat = _formatString;
for (; (*dbcFormat || *sqlFormat); ++dbcFormat, ++sqlFormat)
{
if (_formatString[columnNumber] == FT_SQL_ABSENT)
if (!*dbcFormat || !*sqlFormat)
{
switch (_dbcFormat[columnNumber])
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[offset]) = 0.0f;
offset += 4;
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[offset]) = uint32(0);
offset += 4;
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[offset]) = uint8(0);
offset += 1;
break;
case FT_STRING:
*reinterpret_cast<char**>(&dataValue[offset]) = const_cast<char*>(nullStr);
offset += sizeof(char*);
break;
}
ASSERT(false, "DB and DBC format strings do not have the same length for '%s'", _sqlTableName);
return nullptr;
}
else if (_formatString[columnNumber] == FT_SQL_PRESENT)
{
bool validSqlColumn = true;
switch (_dbcFormat[columnNumber])
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[offset]) = fields[sqlColumnNumber].GetFloat();
offset += 4;
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[offset]) = fields[sqlColumnNumber].GetUInt32();
offset += 4;
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[offset]) = fields[sqlColumnNumber].GetUInt8();
offset += 1;
break;
case FT_STRING:
ASSERT(false, "Unsupported data type in table '%s' at char %d", _sqlTableName.c_str(), columnNumber);
return nullptr;
case FT_SORT:
break;
default:
validSqlColumn = false;
break;
}
if (validSqlColumn && (columnNumber != (_formatString.size() - 1)))
sqlColumnNumber++;
}
else
{
ASSERT(false, "Incorrect sql format string '%s' at char %d", _sqlTableName.c_str(), columnNumber);
if (!*dbcFormat)
break;
switch (*sqlFormat)
{
case FT_SQL_PRESENT:
switch (*dbcFormat)
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[dataOffset]) = fields[sqlColumnNumber].GetFloat();
dataOffset += sizeof(float);
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[dataOffset]) = fields[sqlColumnNumber].GetUInt32();
dataOffset += sizeof(uint32);
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[dataOffset]) = fields[sqlColumnNumber].GetUInt8();
dataOffset += sizeof(uint8);
break;
case FT_STRING:
*reinterpret_cast<char**>(&dataValue[dataOffset]) = CleanupStruct::Clone(fields[sqlColumnNumber].GetString());
dataOffset += sizeof(char*);
case FT_SORT:
break;
default:
ASSERT(false, "Unsupported data type '%c' marked present in table '%s'", *dbcFormat, _sqlTableName);
return nullptr;
}
++sqlColumnNumber;
break;
case FT_SQL_ABSENT:
switch (*dbcFormat)
{
case FT_FLOAT:
*reinterpret_cast<float*>(&dataValue[dataOffset]) = 0.0f;
dataOffset += 4;
break;
case FT_IND:
case FT_INT:
*reinterpret_cast<uint32*>(&dataValue[dataOffset]) = uint32(0);
dataOffset += 4;
break;
case FT_BYTE:
*reinterpret_cast<uint8*>(&dataValue[dataOffset]) = uint8(0);
dataOffset += 1;
break;
case FT_STRING:
*reinterpret_cast<char**>(&dataValue[dataOffset]) = const_cast<char*>(nullStr);
dataOffset += sizeof(char*);
break;
}
break;
default:
ASSERT(false, "Invalid DB format string for '%s'", _sqlTableName);
return nullptr;
}
}
if (sqlColumnNumber != (result->GetFieldCount() - 1))
{
ASSERT(false, "SQL and DBC format strings are not matching for table: '%s'", _sqlTableName.c_str());
return nullptr;
}
ASSERT(offset == _recordSize);
ASSERT(sqlColumnNumber == result->GetFieldCount(), "SQL format string does not match database for table: '%s'", _sqlTableName);
ASSERT(dataOffset == _recordSize);
} while (result->NextRow());
ASSERT(newRecords == result->GetRowCount());

View File

@@ -24,14 +24,14 @@
struct TC_SHARED_API DBCDatabaseLoader
{
DBCDatabaseLoader(std::string const& storageName, std::string const& dbFormatString, std::string const& primaryKey, char const* dbcFormatString);
DBCDatabaseLoader(char const* dbTable, char const* dbFormatString, char const* index, char const* dbcFormatString);
char* Load(uint32& records, char**& indexTable);
private:
std::string const& _formatString;
std::string const& _indexName;
std::string _sqlTableName;
char const* _sqlTableName;
char const* _formatString;
char const* _indexName;
char const* _dbcFormat;
int32 _sqlIndexPos;
uint32 _recordSize;

View File

@@ -30,13 +30,13 @@ DBCStorageBase::~DBCStorageBase()
delete[] strings;
}
bool DBCStorageBase::Load(std::string const& path, char**& indexTable)
bool DBCStorageBase::Load(char const* path, char**& indexTable)
{
indexTable = nullptr;
DBCFileLoader dbc;
// Check if load was sucessful, only then continue
if (!dbc.Load(path.c_str(), _fileFormat))
if (!dbc.Load(path, _fileFormat))
return false;
_fieldCount = dbc.GetCols();
@@ -52,7 +52,7 @@ bool DBCStorageBase::Load(std::string const& path, char**& indexTable)
return indexTable != nullptr;
}
bool DBCStorageBase::LoadStringsFrom(std::string const& path, char** indexTable)
bool DBCStorageBase::LoadStringsFrom(char const* path, char** indexTable)
{
// DBC must be already loaded using Load
if (!indexTable)
@@ -60,7 +60,7 @@ bool DBCStorageBase::LoadStringsFrom(std::string const& path, char** indexTable)
DBCFileLoader dbc;
// Check if load was successful, only then continue
if (!dbc.Load(path.c_str(), _fileFormat))
if (!dbc.Load(path, _fileFormat))
return false;
// load strings from another locale dbc data
@@ -70,7 +70,7 @@ bool DBCStorageBase::LoadStringsFrom(std::string const& path, char** indexTable)
return true;
}
void DBCStorageBase::LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey, char**& indexTable)
void DBCStorageBase::LoadFromDB(char const* table, char const* format, char const* index, char**& indexTable)
{
_dataTableEx = DBCDatabaseLoader(path, dbFormat, primaryKey, _fileFormat).Load(_indexTableSize, indexTable);
_dataTableEx = DBCDatabaseLoader(table, format, index, _fileFormat).Load(_indexTableSize, indexTable);
}

View File

@@ -33,14 +33,14 @@ class TC_SHARED_API DBCStorageBase
char const* GetFormat() const { return _fileFormat; }
uint32 GetFieldCount() const { return _fieldCount; }
virtual bool Load(std::string const& path) = 0;
virtual bool LoadStringsFrom(std::string const& path) = 0;
virtual void LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey) = 0;
virtual bool Load(char const* path) = 0;
virtual bool LoadStringsFrom(char const* path) = 0;
virtual void LoadFromDB(char const* table, char const* format, char const* index) = 0;
protected:
bool Load(std::string const& path, char**& indexTable);
bool LoadStringsFrom(std::string const& path, char** indexTable);
void LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey, char**& indexTable);
bool Load(char const* path, char**& indexTable);
bool LoadStringsFrom(char const* path, char** indexTable);
void LoadFromDB(char const* table, char const* format, char const* index, char**& indexTable);
uint32 _fieldCount;
char const* _fileFormat;
@@ -71,19 +71,19 @@ class DBCStorage : public DBCStorageBase
uint32 GetNumRows() const { return _indexTableSize; }
bool Load(std::string const& path) override
bool Load(char const* path) override
{
return DBCStorageBase::Load(path, _indexTable.AsChar);
}
bool LoadStringsFrom(std::string const& path) override
bool LoadStringsFrom(char const* path) override
{
return DBCStorageBase::LoadStringsFrom(path, _indexTable.AsChar);
}
void LoadFromDB(std::string const& path, std::string const& dbFormat, std::string const& primaryKey) override
void LoadFromDB(char const* table, char const* format, char const* index) override
{
DBCStorageBase::LoadFromDB(path, dbFormat, primaryKey, _indexTable.AsChar);
DBCStorageBase::LoadFromDB(table, format, index, _indexTable.AsChar);
}
iterator begin() { return iterator(_indexTable.AsT, _indexTableSize); }

View File

@@ -19,7 +19,7 @@
#ifndef TRINITY_DBCSFRM_H
#define TRINITY_DBCSFRM_H
char constexpr Achievementfmt[] = "niixssssssssssssssssxxxxxxxxxxxxxxxxxxiixixxxxxxxxxxxxxxxxxxii";
char constexpr Achievementfmt[] = "niixssssssssssssssssxxxxxxxxxxxxxxxxxxiixixxxxxxxxxxxxxxxxxxii";
char constexpr CustomAchievementfmt[] = "pppaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaapapaaaaaaaaaaaaaaaaaapp";
char constexpr CustomAchievementIndex[] = "ID";
char constexpr AchievementCriteriafmt[] = "niiiiiiiixxxxxxxxxxxxxxxxxiiiix";
@@ -109,12 +109,12 @@ char constexpr SkillTiersfmt[] = "nxxxxxxxxxxxxxxxxiiiiiiiiiiiiiiii";
char constexpr SoundEntriesfmt[] = "nxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
char constexpr SpellCastTimefmt[] = "nixx";
char constexpr SpellCategoryfmt[] = "ni";
char constexpr SpellDifficultyfmt[] = "niiii";
char constexpr SpellDifficultyfmt[] = "niiii";
char constexpr CustomSpellDifficultyfmt[] = "ppppp";
char constexpr CustomSpellDifficultyIndex[] = "id";
char constexpr SpellDurationfmt[] = "niii";
char constexpr SpellEntryfmt[] = "niiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiifxiiiiiiiiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiifffiiiiiiiiiiiiiissssssssssssssssxssssssssssssssssxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxiiiiiiiiiiixfffxxxiiiiixxfffxx";
char constexpr CustomSpellEntryfmt[] = "papppppppppppapapaaaaaaaaaaapaaapapppppppaaaaapaapaaaaaaaaaaaaaaaaaappppppppppppppppppppppppppppppppppppaaappppppppppppaaapppppppppaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaappppppppapppaaaaappaaaaaaa";
char constexpr SpellEntryfmt[] = "niiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiiifxiiiiiiiiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiiiiiiiifffiiiiiiiiiiiiiiifffiiiiiiiiiiiiiissssssssssssssssxssssssssssssssssxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxiiiiiiiiiiixfffxxxiiiiixxfffxx";
char constexpr CustomSpellEntryfmt[] = "papppppppppppapapaaaaaaaaaaapaaapapppppppaaaaapaapaaaaaaaaaaaaaaaaaappppppppppppppppppppppppppppppppppppaaappppppppppppaaapppppppppaaaaapaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaappppppppapppaaaaappaaaaaaaa";
char constexpr CustomSpellEntryIndex[] = "Id";
char constexpr SpellFocusObjectfmt[] = "nxxxxxxxxxxxxxxxxx";
char constexpr SpellItemEnchantmentfmt[] = "nxiiiiiixxxiiissssssssssssssssxiiiiiii";