diff options
Diffstat (limited to 'src/server/authserver/Server/BattlenetBitStream.h')
-rw-r--r-- | src/server/authserver/Server/BattlenetBitStream.h | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/src/server/authserver/Server/BattlenetBitStream.h b/src/server/authserver/Server/BattlenetBitStream.h index c1c95236360..82c2a0a6d5d 100644 --- a/src/server/authserver/Server/BattlenetBitStream.h +++ b/src/server/authserver/Server/BattlenetBitStream.h @@ -59,16 +59,16 @@ namespace Battlenet static uint32 const MaxSize = 0x4000; // length : The maximum number of bytes to read - BitStream(uint32 length) : _numBits(length * 8), _readPos(0), _writePos(0) + BitStream(uint32 length) : _writePos(length * 8), _readPos(0) { _buffer.resize(length, 0); } - BitStream(MessageBuffer&& buffer) : _buffer(buffer.Move()), _numBits(_buffer.size() * 8), _readPos(0), _writePos(0) + BitStream(MessageBuffer&& buffer) : _writePos(buffer.GetReadyDataSize() * 8), _readPos(0), _buffer(buffer.Move()) { } - BitStream() : _numBits(0), _readPos(0), _writePos(0) + BitStream() : _writePos(0), _readPos(0) { _buffer.reserve(0x1000); } @@ -91,8 +91,8 @@ namespace Battlenet std::unique_ptr<uint8[]> ReadBytes(uint32 count) { AlignToNextByte(); - if (_readPos + count * 8 > _numBits) - throw BitStreamPositionException(true, count * 8, _readPos, _numBits); + if (_readPos + count * 8 > _writePos) + throw BitStreamPositionException(true, count * 8, _readPos, _writePos); std::unique_ptr<uint8[]> buf(new uint8[count]); memcpy(buf.get(), &_buffer[_readPos >> 3], count); @@ -125,8 +125,8 @@ namespace Battlenet { static_assert(std::is_integral<T>::value || std::is_enum<T>::value, "T must be an integer type"); - if (_readPos + bitCount >= _numBits) - throw BitStreamPositionException(true, bitCount, _readPos, _numBits); + if (_readPos + bitCount >= _writePos) + throw BitStreamPositionException(true, bitCount, _readPos, _writePos); uint64 ret = 0; while (bitCount != 0) @@ -214,26 +214,23 @@ namespace Battlenet void SetReadPos(uint32 bits) { - if (bits >= _numBits) - throw BitStreamPositionException(true, bits, 0, _numBits); + if (bits >= _writePos) + throw BitStreamPositionException(true, bits, 0, _writePos); _readPos = bits; } - bool IsRead() const { return _readPos >= _numBits; } + bool IsRead() const { return _readPos >= _writePos; } uint8* GetBuffer() { return _buffer.data(); } uint8 const* GetBuffer() const { return _buffer.data(); } - size_t GetSize() const { return _buffer.size(); } - - void FinishReading() { _readPos = _numBits; } + size_t GetSize() const { return ((_writePos + 7) & ~7) / 8; } private: - std::vector<uint8> _buffer; - uint32 _numBits; - uint32 _readPos; uint32 _writePos; + uint32 _readPos; + std::vector<uint8> _buffer; }; template<> |