aboutsummaryrefslogtreecommitdiff
path: root/src/server/shared
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2014-08-07 19:02:08 +0200
committerShauren <shauren.trinity@gmail.com>2014-08-10 11:00:27 +0200
commitdf11916ad53e6b2f64cd1af5d5296ba188f3e486 (patch)
tree0ece1cfd1133e8ecedede59956e3de1578798807 /src/server/shared
parent91053d557ca89c4b0c455366afae258835bd25f8 (diff)
Core/NetworkIO: Allow receiving packets bigger than buffer size and properly handle situations where not entire packet was read in one go
Core/Authserver: Restored authenticator functionality
Diffstat (limited to 'src/server/shared')
-rw-r--r--src/server/shared/Networking/MessageBuffer.h93
-rw-r--r--src/server/shared/Networking/Socket.h119
-rw-r--r--src/server/shared/Packets/ByteBuffer.cpp5
-rw-r--r--src/server/shared/Packets/ByteBuffer.h12
-rw-r--r--src/server/shared/Packets/WorldPacket.h2
5 files changed, 201 insertions, 30 deletions
diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h
new file mode 100644
index 00000000000..fff94b86c1e
--- /dev/null
+++ b/src/server/shared/Networking/MessageBuffer.h
@@ -0,0 +1,93 @@
+/*
+* Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/>
+*
+* This program is free software; you can redistribute it and/or modify it
+* under the terms of the GNU General Public License as published by the
+* Free Software Foundation; either version 2 of the License, or (at your
+* option) any later version.
+*
+* This program is distributed in the hope that it will be useful, but WITHOUT
+* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+* more details.
+*
+* You should have received a copy of the GNU General Public License along
+* with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef __MESSAGEBUFFER_H_
+#define __MESSAGEBUFFER_H_
+
+#include "Define.h"
+#include <vector>
+
+class MessageBuffer
+{
+ typedef std::vector<uint8>::size_type size_type;
+
+public:
+ MessageBuffer() : _wpos(0), _storage() { }
+
+ MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _storage(right._storage) { }
+
+ MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _storage(right.Move()) { }
+
+ void Reset()
+ {
+ _storage.clear();
+ _wpos = 0;
+ }
+
+ bool IsMessageReady() const { return _wpos == _storage.size(); }
+
+ size_type GetMissingSize() const { return _storage.size() - _wpos; }
+
+ uint8* Data() { return _storage.data(); }
+
+ void Grow(size_type bytes)
+ {
+ _storage.resize(_storage.size() + bytes);
+ }
+
+ uint8* GetWritePointer() { return &_storage[_wpos]; }
+
+ void WriteCompleted(size_type bytes) { _wpos += bytes; }
+
+ void ResetWritePointer() { _wpos = 0; }
+
+ size_type GetSize() { return _storage.size(); }
+
+ std::vector<uint8>&& Move()
+ {
+ _wpos = 0;
+ return std::move(_storage);
+ }
+
+ MessageBuffer& operator=(MessageBuffer& right)
+ {
+ if (this != &right)
+ {
+ _wpos = right._wpos;
+ _storage = right._storage;
+ }
+
+ return *this;
+ }
+
+ MessageBuffer& operator=(MessageBuffer&& right)
+ {
+ if (this != &right)
+ {
+ _wpos = right._wpos;
+ _storage = right.Move();
+ }
+
+ return *this;
+ }
+
+private:
+ size_type _wpos;
+ std::vector<uint8> _storage;
+};
+
+#endif /* __MESSAGEBUFFER_H_ */
diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h
index 6bf67e06d9c..4a3f2990799 100644
--- a/src/server/shared/Networking/Socket.h
+++ b/src/server/shared/Networking/Socket.h
@@ -18,7 +18,7 @@
#ifndef __SOCKET_H__
#define __SOCKET_H__
-#include "Define.h"
+#include "MessageBuffer.h"
#include "Log.h"
#include <vector>
#include <mutex>
@@ -28,19 +28,23 @@
#include <type_traits>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/write.hpp>
+#include <boost/asio/read.hpp>
using boost::asio::ip::tcp;
+#define READ_BLOCK_SIZE 4096
+
template<class T, class PacketType>
class Socket : public std::enable_shared_from_this<T>
{
typedef typename std::conditional<std::is_pointer<PacketType>::value, PacketType, PacketType const&>::type WritePacketType;
public:
- Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _headerSize(headerSize)
+ Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket))
{
- _remotePort = _socket.remote_endpoint().port();
_remoteAddress = _socket.remote_endpoint().address();
+ _remotePort = _socket.remote_endpoint().port();
+ _readHeaderBuffer.Grow(headerSize);
}
virtual void Start() = 0;
@@ -57,25 +61,39 @@ public:
void AsyncReadHeader()
{
- _socket.async_read_some(boost::asio::buffer(_readBuffer, _headerSize), std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(),
- std::placeholders::_1, std::placeholders::_2));
+ _readHeaderBuffer.ResetWritePointer();
+ _readDataBuffer.Reset();
+
+ AsyncReadMissingHeaderData();
}
- void AsyncReadData(std::size_t size, std::size_t bufferOffset)
+ void AsyncReadData(std::size_t size)
{
- _socket.async_read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(),
- std::placeholders::_1, std::placeholders::_2));
+ if (!size)
+ {
+ // if this is a packet with 0 length body just invoke handler directly
+ ReadDataHandler();
+ return;
+ }
+
+ _readDataBuffer.Grow(size);
+ AsyncReadMissingData();
}
- void ReadData(std::size_t size, std::size_t bufferOffset)
+ void ReadData(std::size_t size)
{
boost::system::error_code error;
- _socket.read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), error);
+ _readDataBuffer.Grow(size);
- if (error)
+ std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readDataBuffer.GetWritePointer(), size), error);
+
+ _readDataBuffer.WriteCompleted(bytesRead);
+
+ if (error || !_readDataBuffer.IsMessageReady())
{
- TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(), error.message().c_str());
+ TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(),
+ error.message().c_str());
CloseSocket();
}
@@ -83,8 +101,8 @@ public:
void AsyncWrite(WritePacketType data)
{
- boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(), std::placeholders::_1,
- std::placeholders::_2));
+ boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(),
+ std::placeholders::_1, std::placeholders::_2));
}
bool IsOpen() const { return _socket.is_open(); }
@@ -94,7 +112,7 @@ public:
_socket.shutdown(boost::asio::socket_base::shutdown_both, shutdownError);
if (shutdownError)
TC_LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(),
- shutdownError.value(), shutdownError.message().c_str());
+ shutdownError.value(), shutdownError.message().c_str());
boost::system::error_code error;
_socket.close(error);
@@ -103,18 +121,72 @@ public:
error.value(), error.message().c_str());
}
- uint8* GetReadBuffer() { return _readBuffer; }
+ virtual bool IsHeaderReady() const { return _readHeaderBuffer.IsMessageReady(); }
+ virtual bool IsDataReady() const { return _readDataBuffer.IsMessageReady(); }
+
+ uint8* GetHeaderBuffer() { return _readHeaderBuffer.Data(); }
+ uint8* GetDataBuffer() { return _readDataBuffer.Data(); }
+
+ MessageBuffer&& MoveHeader() { return std::move(_readHeaderBuffer); }
+ MessageBuffer&& MoveData() { return std::move(_readDataBuffer); }
protected:
- virtual void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) = 0;
- virtual void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) = 0;
+ virtual void ReadHeaderHandler() = 0;
+ virtual void ReadDataHandler() = 0;
std::mutex _writeLock;
std::queue<PacketType> _writeQueue;
private:
- void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadHeaderHandler(error, transferedBytes); }
- void ReadDataHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadDataHandler(error, transferedBytes); }
+ void AsyncReadMissingHeaderData()
+ {
+ _socket.async_read_some(boost::asio::buffer(_readHeaderBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readHeaderBuffer.GetMissingSize())),
+ std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
+ void AsyncReadMissingData()
+ {
+ _socket.async_read_some(boost::asio::buffer(_readDataBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readDataBuffer.GetMissingSize())),
+ std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
+ void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+ {
+ if (error)
+ {
+ CloseSocket();
+ return;
+ }
+
+ _readHeaderBuffer.WriteCompleted(transferredBytes);
+ if (!IsHeaderReady())
+ {
+ // incomplete, read more
+ AsyncReadMissingHeaderData();
+ return;
+ }
+
+ ReadHeaderHandler();
+ }
+
+ void ReadDataHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+ {
+ if (error)
+ {
+ CloseSocket();
+ return;
+ }
+
+ _readDataBuffer.WriteCompleted(transferredBytes);
+ if (!IsDataReady())
+ {
+ // incomplete, read more
+ AsyncReadMissingData();
+ return;
+ }
+
+ ReadDataHandler();
+ }
void WriteHandler(boost::system::error_code error, size_t /*transferedBytes*/)
{
@@ -140,12 +212,11 @@ private:
tcp::socket _socket;
- uint8 _readBuffer[4096];
-
- uint16 _remotePort;
boost::asio::ip::address _remoteAddress;
+ uint16 _remotePort;
- std::size_t _headerSize;
+ MessageBuffer _readHeaderBuffer;
+ MessageBuffer _readDataBuffer;
};
#endif // __SOCKET_H__
diff --git a/src/server/shared/Packets/ByteBuffer.cpp b/src/server/shared/Packets/ByteBuffer.cpp
index 86234039a4a..3785d1c29fa 100644
--- a/src/server/shared/Packets/ByteBuffer.cpp
+++ b/src/server/shared/Packets/ByteBuffer.cpp
@@ -17,11 +17,16 @@
*/
#include "ByteBuffer.h"
+#include "MessageBuffer.h"
#include "Common.h"
#include "Log.h"
#include <sstream>
+ByteBuffer::ByteBuffer(MessageBuffer&& buffer) : _rpos(0), _wpos(0), _storage(buffer.Move())
+{
+}
+
ByteBufferPositionException::ByteBufferPositionException(bool add, size_t pos,
size_t size, size_t valueSize)
{
diff --git a/src/server/shared/Packets/ByteBuffer.h b/src/server/shared/Packets/ByteBuffer.h
index c678e9dce06..456223d744d 100644
--- a/src/server/shared/Packets/ByteBuffer.h
+++ b/src/server/shared/Packets/ByteBuffer.h
@@ -34,6 +34,8 @@
#include <math.h>
#include <boost/asio/buffer.hpp>
+class MessageBuffer;
+
// Root of ByteBuffer exception hierarchy
class ByteBufferException : public std::exception
{
@@ -82,14 +84,12 @@ class ByteBuffer
}
ByteBuffer(ByteBuffer&& buf) : _rpos(buf._rpos), _wpos(buf._wpos),
- _storage(std::move(buf._storage))
- {
- }
+ _storage(std::move(buf._storage)) { }
ByteBuffer(ByteBuffer const& right) : _rpos(right._rpos), _wpos(right._wpos),
- _storage(right._storage)
- {
- }
+ _storage(right._storage) { }
+
+ ByteBuffer(MessageBuffer&& buffer);
ByteBuffer& operator=(ByteBuffer const& right)
{
diff --git a/src/server/shared/Packets/WorldPacket.h b/src/server/shared/Packets/WorldPacket.h
index 8851b9f3e45..848a00739fe 100644
--- a/src/server/shared/Packets/WorldPacket.h
+++ b/src/server/shared/Packets/WorldPacket.h
@@ -51,6 +51,8 @@ class WorldPacket : public ByteBuffer
return *this;
}
+ WorldPacket(uint16 opcode, MessageBuffer&& buffer) : ByteBuffer(std::move(buffer)), m_opcode(opcode) { }
+
void Initialize(uint16 opcode, size_t newres=200)
{
clear();