aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLadislav Zezula <zezula@volny.cz>2024-04-21 14:14:28 +0200
committerLadislav Zezula <zezula@volny.cz>2024-04-21 14:14:28 +0200
commit86b6888f3ece894d02ef625ff16939a14670ed98 (patch)
tree9e8d2cb2188cc540e9726edcbfea8c47543d5b96
parent605222393594f5885b877bfc0086dae756674965 (diff)
Fixed heap overrun in https://github.com/ladislav-zezula/StormLib/issues/327
-rw-r--r--src/SBaseFileTable.cpp91
-rwxr-xr-xtest/StormTest.cpp9
2 files changed, 64 insertions, 36 deletions
diff --git a/src/SBaseFileTable.cpp b/src/SBaseFileTable.cpp
index a222ac5..074501e 100644
--- a/src/SBaseFileTable.cpp
+++ b/src/SBaseFileTable.cpp
@@ -63,8 +63,8 @@ struct TMPQBits
{
static TMPQBits * Create(DWORD NumberOfBits, BYTE FillValue);
- void GetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize);
- void SetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize);
+ DWORD GetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize);
+ DWORD SetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize);
static const USHORT SetBitsMask[]; // Bit mask for each number of bits (0-8)
@@ -94,7 +94,7 @@ TMPQBits * TMPQBits::Create(
return pBitArray;
}
-void TMPQBits::GetBits(
+DWORD TMPQBits::GetBits(
unsigned int nBitPosition,
unsigned int nBitLength,
void * pvBuffer,
@@ -110,6 +110,12 @@ void TMPQBits::GetBits(
// Keep compilers happy for platforms where nResultByteSize is not used
STORMLIB_UNUSED(nResultByteSize);
+ // Check for bit overflow
+ if(nBitPosition + nBitLength < nBitPosition)
+ return ERROR_BUFFER_OVERFLOW;
+ if(nBitPosition + nBitLength > NumberOfBits)
+ return ERROR_BUFFER_OVERFLOW;
+
#ifdef _DEBUG
// Check if the target is properly zeroed
for(int i = 0; i < nResultByteSize; i++)
@@ -157,9 +163,10 @@ void TMPQBits::GetBits(
*pbBuffer &= (0x01 << nBitLength) - 1;
}
+ return ERROR_SUCCESS;
}
-void TMPQBits::SetBits(
+DWORD TMPQBits::SetBits(
unsigned int nBitPosition,
unsigned int nBitLength,
void * pvBuffer,
@@ -175,6 +182,12 @@ void TMPQBits::SetBits(
// Keep compilers happy for platforms where nResultByteSize is not used
STORMLIB_UNUSED(nResultByteSize);
+ // Check for bit overflow
+ if(nBitPosition + nBitLength < nBitPosition)
+ return ERROR_BUFFER_OVERFLOW;
+ if(nBitPosition + nBitLength > NumberOfBits)
+ return ERROR_BUFFER_OVERFLOW;
+
#ifndef STORMLIB_LITTLE_ENDIAN
// Adjust the buffer pointer for big endian platforms
pbBuffer += (nResultByteSize - 1);
@@ -223,6 +236,7 @@ void TMPQBits::SetBits(
Elements[nBytePosition] = (BYTE)((Elements[nBytePosition] & ~AndMask) | BitBuffer);
}
}
+ return ERROR_SUCCESS;
}
void GetMPQBits(TMPQBits * pBits, unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultByteSize)
@@ -2599,7 +2613,7 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha)
TMPQBits * pBitArray;
DWORD dwBitPosition = 0;
DWORD i;
- DWORD dwErrCode = ERROR_FILE_CORRUPT;
+ DWORD dwErrCode = ERROR_SUCCESS;
// Load the BET table from the MPQ
pBetTable = LoadBetTable(ha);
@@ -2622,10 +2636,16 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha)
if(pHetTable->pNameHashes[i] != HET_ENTRY_FREE)
{
// Load the index to the BET table
- pHetTable->pBetIndexes->GetBits(pHetTable->dwIndexSizeTotal * i,
- pHetTable->dwIndexSize,
- &dwFileIndex,
- 4);
+ dwErrCode = pHetTable->pBetIndexes->GetBits(pHetTable->dwIndexSizeTotal * i,
+ pHetTable->dwIndexSize,
+ &dwFileIndex,
+ 4);
+ if(dwErrCode != ERROR_SUCCESS)
+ {
+ FreeBetTable(pBetTable);
+ return ERROR_FILE_CORRUPT;
+ }
+
// Overflow test
if(dwFileIndex < pBetTable->dwEntryCount)
{
@@ -2633,10 +2653,15 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha)
ULONGLONG NameHash2 = 0;
// Load the BET hash
- pBetTable->pNameHashes->GetBits(pBetTable->dwBitTotal_NameHash2 * dwFileIndex,
- pBetTable->dwBitCount_NameHash2,
- &NameHash2,
- 8);
+ dwErrCode = pBetTable->pNameHashes->GetBits(pBetTable->dwBitTotal_NameHash2 * dwFileIndex,
+ pBetTable->dwBitCount_NameHash2,
+ &NameHash2,
+ 8);
+ if(dwErrCode != ERROR_SUCCESS)
+ {
+ FreeBetTable(pBetTable);
+ return ERROR_FILE_CORRUPT;
+ }
// Combine both part of the name hash and put it to the file table
pFileEntry = ha->pFileTable + dwFileIndex;
@@ -2653,31 +2678,35 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha)
DWORD dwFlagIndex = 0;
// Read the file position
- pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FilePos,
- pBetTable->dwBitCount_FilePos,
- &pFileEntry->ByteOffset,
- 8);
+ if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FilePos,
+ pBetTable->dwBitCount_FilePos,
+ &pFileEntry->ByteOffset,
+ 8)) != ERROR_SUCCESS)
+ break;
// Read the file size
- pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FileSize,
- pBetTable->dwBitCount_FileSize,
- &pFileEntry->dwFileSize,
- 4);
+ if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FileSize,
+ pBetTable->dwBitCount_FileSize,
+ &pFileEntry->dwFileSize,
+ 4)) != ERROR_SUCCESS)
+ break;
// Read the compressed size
- pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_CmpSize,
- pBetTable->dwBitCount_CmpSize,
- &pFileEntry->dwCmpSize,
- 4);
-
+ if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_CmpSize,
+ pBetTable->dwBitCount_CmpSize,
+ &pFileEntry->dwCmpSize,
+ 4)) != ERROR_SUCCESS)
+ break;
// Read the flag index
if(pBetTable->dwFlagCount != 0)
{
- pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FlagIndex,
- pBetTable->dwBitCount_FlagIndex,
- &dwFlagIndex,
- 4);
+ if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FlagIndex,
+ pBetTable->dwBitCount_FlagIndex,
+ &dwFlagIndex,
+ 4)) != ERROR_SUCCESS)
+ break;
+
pFileEntry->dwFlags = pBetTable->pFileFlags[dwFlagIndex];
}
@@ -2692,13 +2721,11 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha)
// Set the current size of the file table
FreeBetTable(pBetTable);
- dwErrCode = ERROR_SUCCESS;
}
else
{
dwErrCode = ERROR_FILE_CORRUPT;
}
-
return dwErrCode;
}
diff --git a/test/StormTest.cpp b/test/StormTest.cpp
index 3276c59..718c910 100755
--- a/test/StormTest.cpp
+++ b/test/StormTest.cpp
@@ -4183,9 +4183,10 @@ static const LPCSTR Test_CreateMpq_Localized[] =
// Main
#define TEST_COMMAND_LINE
-#define TEST_LOCAL_LISTFILE
-#define TEST_STREAM_OPERATIONS
-#define TEST_MASTER_MIRROR
+//#define TEST_LOCAL_LISTFILE
+//#define TEST_STREAM_OPERATIONS
+//#define TEST_MASTER_MIRROR
+#define TEST_OPEN_MPQ
#define TEST_OPEN_MPQ
#define TEST_REOPEN_MPQ
#define TEST_VERIFY_SIGNATURE
@@ -4211,7 +4212,7 @@ int _tmain(int argc, TCHAR * argv[])
#ifdef TEST_COMMAND_LINE
// Test-open MPQs from the command line. They must be plain name
// and must be placed in the Test-MPQs folder
- for(int i = 1; i < argc; i++)
+ for(int i = 2; i < argc; i++)
{
TestOpenArchive(argv[i], NULL, NULL, 0, &LfBliz);
}