From 86b6888f3ece894d02ef625ff16939a14670ed98 Mon Sep 17 00:00:00 2001 From: Ladislav Zezula Date: Sun, 21 Apr 2024 14:14:28 +0200 Subject: Fixed heap overrun in https://github.com/ladislav-zezula/StormLib/issues/327 --- src/SBaseFileTable.cpp | 91 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 32 deletions(-) (limited to 'src') diff --git a/src/SBaseFileTable.cpp b/src/SBaseFileTable.cpp index a222ac5..074501e 100644 --- a/src/SBaseFileTable.cpp +++ b/src/SBaseFileTable.cpp @@ -63,8 +63,8 @@ struct TMPQBits { static TMPQBits * Create(DWORD NumberOfBits, BYTE FillValue); - void GetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize); - void SetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize); + DWORD GetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize); + DWORD SetBits(unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultSize); static const USHORT SetBitsMask[]; // Bit mask for each number of bits (0-8) @@ -94,7 +94,7 @@ TMPQBits * TMPQBits::Create( return pBitArray; } -void TMPQBits::GetBits( +DWORD TMPQBits::GetBits( unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, @@ -110,6 +110,12 @@ void TMPQBits::GetBits( // Keep compilers happy for platforms where nResultByteSize is not used STORMLIB_UNUSED(nResultByteSize); + // Check for bit overflow + if(nBitPosition + nBitLength < nBitPosition) + return ERROR_BUFFER_OVERFLOW; + if(nBitPosition + nBitLength > NumberOfBits) + return ERROR_BUFFER_OVERFLOW; + #ifdef _DEBUG // Check if the target is properly zeroed for(int i = 0; i < nResultByteSize; i++) @@ -157,9 +163,10 @@ void TMPQBits::GetBits( *pbBuffer &= (0x01 << nBitLength) - 1; } + return ERROR_SUCCESS; } -void TMPQBits::SetBits( +DWORD TMPQBits::SetBits( unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, @@ -175,6 +182,12 @@ void TMPQBits::SetBits( // Keep compilers happy for platforms where nResultByteSize is not used STORMLIB_UNUSED(nResultByteSize); + // Check for bit overflow + if(nBitPosition + nBitLength < nBitPosition) + return ERROR_BUFFER_OVERFLOW; + if(nBitPosition + nBitLength > NumberOfBits) + return ERROR_BUFFER_OVERFLOW; + #ifndef STORMLIB_LITTLE_ENDIAN // Adjust the buffer pointer for big endian platforms pbBuffer += (nResultByteSize - 1); @@ -223,6 +236,7 @@ void TMPQBits::SetBits( Elements[nBytePosition] = (BYTE)((Elements[nBytePosition] & ~AndMask) | BitBuffer); } } + return ERROR_SUCCESS; } void GetMPQBits(TMPQBits * pBits, unsigned int nBitPosition, unsigned int nBitLength, void * pvBuffer, int nResultByteSize) @@ -2599,7 +2613,7 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha) TMPQBits * pBitArray; DWORD dwBitPosition = 0; DWORD i; - DWORD dwErrCode = ERROR_FILE_CORRUPT; + DWORD dwErrCode = ERROR_SUCCESS; // Load the BET table from the MPQ pBetTable = LoadBetTable(ha); @@ -2622,10 +2636,16 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha) if(pHetTable->pNameHashes[i] != HET_ENTRY_FREE) { // Load the index to the BET table - pHetTable->pBetIndexes->GetBits(pHetTable->dwIndexSizeTotal * i, - pHetTable->dwIndexSize, - &dwFileIndex, - 4); + dwErrCode = pHetTable->pBetIndexes->GetBits(pHetTable->dwIndexSizeTotal * i, + pHetTable->dwIndexSize, + &dwFileIndex, + 4); + if(dwErrCode != ERROR_SUCCESS) + { + FreeBetTable(pBetTable); + return ERROR_FILE_CORRUPT; + } + // Overflow test if(dwFileIndex < pBetTable->dwEntryCount) { @@ -2633,10 +2653,15 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha) ULONGLONG NameHash2 = 0; // Load the BET hash - pBetTable->pNameHashes->GetBits(pBetTable->dwBitTotal_NameHash2 * dwFileIndex, - pBetTable->dwBitCount_NameHash2, - &NameHash2, - 8); + dwErrCode = pBetTable->pNameHashes->GetBits(pBetTable->dwBitTotal_NameHash2 * dwFileIndex, + pBetTable->dwBitCount_NameHash2, + &NameHash2, + 8); + if(dwErrCode != ERROR_SUCCESS) + { + FreeBetTable(pBetTable); + return ERROR_FILE_CORRUPT; + } // Combine both part of the name hash and put it to the file table pFileEntry = ha->pFileTable + dwFileIndex; @@ -2653,31 +2678,35 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha) DWORD dwFlagIndex = 0; // Read the file position - pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FilePos, - pBetTable->dwBitCount_FilePos, - &pFileEntry->ByteOffset, - 8); + if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FilePos, + pBetTable->dwBitCount_FilePos, + &pFileEntry->ByteOffset, + 8)) != ERROR_SUCCESS) + break; // Read the file size - pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FileSize, - pBetTable->dwBitCount_FileSize, - &pFileEntry->dwFileSize, - 4); + if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FileSize, + pBetTable->dwBitCount_FileSize, + &pFileEntry->dwFileSize, + 4)) != ERROR_SUCCESS) + break; // Read the compressed size - pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_CmpSize, - pBetTable->dwBitCount_CmpSize, - &pFileEntry->dwCmpSize, - 4); - + if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_CmpSize, + pBetTable->dwBitCount_CmpSize, + &pFileEntry->dwCmpSize, + 4)) != ERROR_SUCCESS) + break; // Read the flag index if(pBetTable->dwFlagCount != 0) { - pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FlagIndex, - pBetTable->dwBitCount_FlagIndex, - &dwFlagIndex, - 4); + if((dwErrCode = pBitArray->GetBits(dwBitPosition + pBetTable->dwBitIndex_FlagIndex, + pBetTable->dwBitCount_FlagIndex, + &dwFlagIndex, + 4)) != ERROR_SUCCESS) + break; + pFileEntry->dwFlags = pBetTable->pFileFlags[dwFlagIndex]; } @@ -2692,13 +2721,11 @@ static DWORD BuildFileTable_HetBet(TMPQArchive * ha) // Set the current size of the file table FreeBetTable(pBetTable); - dwErrCode = ERROR_SUCCESS; } else { dwErrCode = ERROR_FILE_CORRUPT; } - return dwErrCode; } -- cgit v1.2.3