aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/server/authserver/Server/AuthSession.cpp81
-rw-r--r--src/server/authserver/Server/AuthSession.h4
-rw-r--r--src/server/game/Server/WorldSocket.cpp109
-rw-r--r--src/server/game/Server/WorldSocket.h4
-rw-r--r--src/server/shared/Common.h40
-rw-r--r--src/server/shared/Database/DatabaseWorkerPool.h95
-rw-r--r--src/server/shared/Database/MySQLConnection.cpp8
-rw-r--r--src/server/shared/Networking/MessageBuffer.h93
-rw-r--r--src/server/shared/Networking/Socket.h129
-rw-r--r--src/server/shared/Packets/ByteBuffer.cpp5
-rw-r--r--src/server/shared/Packets/ByteBuffer.h12
-rw-r--r--src/server/shared/Packets/WorldPacket.h2
-rw-r--r--src/server/worldserver/Main.cpp57
-rw-r--r--src/tools/mmaps_generator/PathGenerator.cpp2
14 files changed, 374 insertions, 267 deletions
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp
index cdd8298174f..76f8b8c27b0 100644
--- a/src/server/authserver/Server/AuthSession.cpp
+++ b/src/server/authserver/Server/AuthSession.cpp
@@ -21,6 +21,7 @@
#include "AuthCodes.h"
#include "Database/DatabaseEnv.h"
#include "SHA1.h"
+#include "TOTP.h"
#include "openssl/crypto.h"
#include "Configuration/Config.h"
#include "RealmList.h"
@@ -52,7 +53,6 @@ enum eStatus
typedef struct AUTH_LOGON_CHALLENGE_C
{
- uint8 cmd;
uint8 error;
uint16 size;
uint8 gamename[4];
@@ -71,7 +71,6 @@ typedef struct AUTH_LOGON_CHALLENGE_C
typedef struct AUTH_LOGON_PROOF_C
{
- uint8 cmd;
uint8 A[32];
uint8 M1[20];
uint8 crc_hash[20];
@@ -99,7 +98,6 @@ typedef struct AUTH_LOGON_PROOF_S_OLD
typedef struct AUTH_RECONNECT_PROOF_C
{
- uint8 cmd;
uint8 R1[16];
uint8 R2[20];
uint8 R3[20];
@@ -114,10 +112,10 @@ enum class BufferSizes : uint32
SRP_6_S = 0x20,
};
-#define REALM_LIST_PACKET_SIZE 5
-#define XFER_ACCEPT_SIZE 1
-#define XFER_RESUME_SIZE 9
-#define XFER_CANCEL_SIZE 1
+#define REALM_LIST_PACKET_SIZE 4
+#define XFER_ACCEPT_SIZE 0
+#define XFER_RESUME_SIZE 8
+#define XFER_CANCEL_SIZE 0
std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers()
{
@@ -137,44 +135,36 @@ std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers()
std::unordered_map<uint8, AuthHandler> const Handlers = AuthSession::InitHandlers();
-void AuthSession::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes)
+void AuthSession::ReadHeaderHandler()
{
- if (!error && transferedBytes == 1)
+ uint8 cmd = GetHeaderBuffer()[0];
+ auto itr = Handlers.find(cmd);
+ if (itr != Handlers.end())
{
- uint8 cmd = GetReadBuffer()[0];
- auto itr = Handlers.find(cmd);
- if (itr != Handlers.end())
+ // Handle dynamic size packet
+ if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE)
{
- // Handle dynamic size packet
- if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE)
- {
- ReadData(sizeof(uint8) + sizeof(uint16), sizeof(cmd)); //error + size
- sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer());
+ ReadData(sizeof(uint8) + sizeof(uint16)); //error + size
+ sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer());
- AsyncReadData(challenge->size, sizeof(uint8) + sizeof(uint8) + sizeof(uint16)); // cmd + error + size
- }
- else
- AsyncReadData(itr->second.packetSize, sizeof(uint8));
+ AsyncReadData(challenge->size);
}
+ else
+ AsyncReadData(itr->second.packetSize);
}
else
CloseSocket();
}
-void AuthSession::ReadDataHandler(boost::system::error_code error, size_t transferedBytes)
+void AuthSession::ReadDataHandler()
{
- if (!error && transferedBytes > 0)
+ if (!(*this.*Handlers.at(GetHeaderBuffer()[0]).handler)())
{
- if (!(*this.*Handlers.at(GetReadBuffer()[0]).handler)())
- {
- CloseSocket();
- return;
- }
-
- AsyncReadHeader();
- }
- else
CloseSocket();
+ return;
+ }
+
+ AsyncReadHeader();
}
void AuthSession::AsyncWrite(ByteBuffer& packet)
@@ -191,7 +181,7 @@ void AuthSession::AsyncWrite(ByteBuffer& packet)
bool AuthSession::HandleLogonChallenge()
{
- sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer());
+ sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer());
//TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size);
TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I);
@@ -410,7 +400,7 @@ bool AuthSession::HandleLogonProof()
TC_LOG_DEBUG("server.authserver", "Entering _HandleLogonProof");
// Read the packet
- sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetReadBuffer());
+ sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetDataBuffer());
// If the client has no valid version
if (_expversion == NO_VALID_EXP_FLAG)
@@ -522,17 +512,12 @@ bool AuthSession::HandleLogonProof()
// Check auth token
if ((logonProof->securityFlags & 0x04) || !_tokenKey.empty())
{
- // TODO To be fixed
-
- /*
- uint8 size;
- socket().recv((char*)&size, 1);
- char* token = new char[size + 1];
- token[size] = '\0';
- socket().recv(token, size);
- unsigned int validToken = TOTP::GenerateToken(_tokenKey.c_str());
- unsigned int incomingToken = atoi(token);
- delete[] token;
+ ReadData(1);
+ uint8 size = *(GetDataBuffer() + sizeof(sAuthLogonProof_C));
+ ReadData(size);
+ std::string token(reinterpret_cast<char*>(GetDataBuffer() + sizeof(sAuthLogonProof_C) + sizeof(size)), size);
+ uint32 validToken = TOTP::GenerateToken(_tokenKey.c_str());
+ uint32 incomingToken = atoi(token.c_str());
if (validToken != incomingToken)
{
ByteBuffer packet;
@@ -542,7 +527,7 @@ bool AuthSession::HandleLogonProof()
packet << uint8(0);
AsyncWrite(packet);
return false;
- }*/
+ }
}
ByteBuffer packet;
@@ -650,7 +635,7 @@ bool AuthSession::HandleLogonProof()
bool AuthSession::HandleReconnectChallenge()
{
TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectChallenge");
- sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer());
+ sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer());
//TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size);
TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I);
@@ -701,7 +686,7 @@ bool AuthSession::HandleReconnectChallenge()
bool AuthSession::HandleReconnectProof()
{
TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectProof");
- sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetReadBuffer());
+ sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetDataBuffer());
if (_login.empty() || !_reconnectProof.GetNumBytes() || !K.GetNumBytes())
return false;
diff --git a/src/server/authserver/Server/AuthSession.h b/src/server/authserver/Server/AuthSession.h
index 5a05ee6f8e9..3497e3a030c 100644
--- a/src/server/authserver/Server/AuthSession.h
+++ b/src/server/authserver/Server/AuthSession.h
@@ -53,8 +53,8 @@ public:
void AsyncWrite(ByteBuffer& packet);
protected:
- void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) override;
- void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) override;
+ void ReadHeaderHandler() override;
+ void ReadDataHandler() override;
private:
bool HandleLogonChallenge();
diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp
index 65a424d5d75..046cdc0acd3 100644
--- a/src/server/game/Server/WorldSocket.cpp
+++ b/src/server/game/Server/WorldSocket.cpp
@@ -54,89 +54,72 @@ void WorldSocket::HandleSendAuthSession()
AsyncWrite(packet);
}
-void WorldSocket::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes)
+void WorldSocket::ReadHeaderHandler()
{
- if (!error && transferedBytes == sizeof(ClientPktHeader))
- {
- _authCrypt.DecryptRecv(GetReadBuffer(), sizeof(ClientPktHeader));
+ _authCrypt.DecryptRecv(GetHeaderBuffer(), sizeof(ClientPktHeader));
- ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetReadBuffer());
- EndianConvertReverse(header->size);
- EndianConvert(header->cmd);
+ ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetHeaderBuffer());
+ EndianConvertReverse(header->size);
+ EndianConvert(header->cmd);
- AsyncReadData(header->size - sizeof(header->cmd), sizeof(ClientPktHeader));
- }
- else
- CloseSocket();
+ AsyncReadData(header->size - sizeof(header->cmd));
}
-void WorldSocket::ReadDataHandler(boost::system::error_code error, size_t transferedBytes)
+void WorldSocket::ReadDataHandler()
{
- ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetReadBuffer());
+ ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetHeaderBuffer());
- if (!error && transferedBytes == (header->size - sizeof(header->cmd)))
- {
- header->size -= sizeof(header->cmd);
-
- uint16 opcode = uint16(header->cmd);
+ header->size -= sizeof(header->cmd);
- std::string opcodeName = GetOpcodeNameForLogging(opcode);
+ uint16 opcode = uint16(header->cmd);
- WorldPacket packet(opcode, header->size);
+ std::string opcodeName = GetOpcodeNameForLogging(opcode);
- if (header->size > 0)
- {
- packet.resize(header->size);
+ WorldPacket packet(opcode, MoveData());
- std::memcpy(packet.contents(), &(GetReadBuffer()[sizeof(ClientPktHeader)]), header->size);
- }
-
- if (sPacketLog->CanLogPacket())
- sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort());
+ if (sPacketLog->CanLogPacket())
+ sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort());
- TC_LOG_TRACE("network.opcode", "C->S: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()).c_str(), GetOpcodeNameForLogging(opcode).c_str());
+ TC_LOG_TRACE("network.opcode", "C->S: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()).c_str(), opcodeName.c_str());
- switch (opcode)
- {
- case CMSG_PING:
- HandlePing(packet);
+ switch (opcode)
+ {
+ case CMSG_PING:
+ HandlePing(packet);
+ break;
+ case CMSG_AUTH_SESSION:
+ if (_worldSession)
+ {
+ TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
break;
- case CMSG_AUTH_SESSION:
- if (_worldSession)
- {
- TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
- break;
- }
+ }
- sScriptMgr->OnPacketReceive(shared_from_this(), packet);
- HandleAuthSession(packet);
- break;
- case CMSG_KEEP_ALIVE:
- TC_LOG_DEBUG("network", "%s", opcodeName.c_str());
- sScriptMgr->OnPacketReceive(shared_from_this(), packet);
- break;
- default:
+ sScriptMgr->OnPacketReceive(shared_from_this(), packet);
+ HandleAuthSession(packet);
+ break;
+ case CMSG_KEEP_ALIVE:
+ TC_LOG_DEBUG("network", "%s", opcodeName.c_str());
+ sScriptMgr->OnPacketReceive(shared_from_this(), packet);
+ break;
+ default:
+ {
+ if (!_worldSession)
{
- if (!_worldSession)
- {
- TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
- break;
- }
-
- // Our Idle timer will reset on any non PING opcodes.
- // Catches people idling on the login screen and any lingering ingame connections.
- _worldSession->ResetTimeOutTime();
-
- // Copy the packet to the heap before enqueuing
- _worldSession->QueuePacket(new WorldPacket(packet));
+ TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
break;
}
- }
- AsyncReadHeader();
+ // Our Idle timer will reset on any non PING opcodes.
+ // Catches people idling on the login screen and any lingering ingame connections.
+ _worldSession->ResetTimeOutTime();
+
+ // Copy the packet to the heap before enqueuing
+ _worldSession->QueuePacket(new WorldPacket(std::move(packet)));
+ break;
+ }
}
- else
- CloseSocket();
+
+ AsyncReadHeader();
}
void WorldSocket::AsyncWrite(WorldPacket& packet)
diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h
index 7275da5ff29..8d452677650 100644
--- a/src/server/game/Server/WorldSocket.h
+++ b/src/server/game/Server/WorldSocket.h
@@ -108,8 +108,8 @@ public:
void AsyncWrite(WorldPacket& packet);
protected:
- void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) override;
- void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) override;
+ void ReadHeaderHandler() override;
+ void ReadDataHandler() override;
private:
void HandleSendAuthSession();
diff --git a/src/server/shared/Common.h b/src/server/shared/Common.h
index 0a1389c1f38..ab268835046 100644
--- a/src/server/shared/Common.h
+++ b/src/server/shared/Common.h
@@ -19,46 +19,6 @@
#ifndef TRINITYCORE_COMMON_H
#define TRINITYCORE_COMMON_H
-// config.h needs to be included 1st
-/// @todo this thingy looks like hack, but its not, need to
-// make separate header however, because It makes mess here.
-#ifdef HAVE_CONFIG_H
-// Remove Some things that we will define
-// This is in case including another config.h
-// before trinity config.h
-#ifdef PACKAGE
-#undef PACKAGE
-#endif //PACKAGE
-#ifdef PACKAGE_BUGREPORT
-#undef PACKAGE_BUGREPORT
-#endif //PACKAGE_BUGREPORT
-#ifdef PACKAGE_NAME
-#undef PACKAGE_NAME
-#endif //PACKAGE_NAME
-#ifdef PACKAGE_STRING
-#undef PACKAGE_STRING
-#endif //PACKAGE_STRING
-#ifdef PACKAGE_TARNAME
-#undef PACKAGE_TARNAME
-#endif //PACKAGE_TARNAME
-#ifdef PACKAGE_VERSION
-#undef PACKAGE_VERSION
-#endif //PACKAGE_VERSION
-#ifdef VERSION
-#undef VERSION
-#endif //VERSION
-
-# include "Config.h"
-
-#undef PACKAGE
-#undef PACKAGE_BUGREPORT
-#undef PACKAGE_NAME
-#undef PACKAGE_STRING
-#undef PACKAGE_TARNAME
-#undef PACKAGE_VERSION
-#undef VERSION
-#endif //HAVE_CONFIG_H
-
#include "Define.h"
#include <unordered_map>
diff --git a/src/server/shared/Database/DatabaseWorkerPool.h b/src/server/shared/Database/DatabaseWorkerPool.h
index 39f1a8da3c2..e95dfc1e484 100644
--- a/src/server/shared/Database/DatabaseWorkerPool.h
+++ b/src/server/shared/Database/DatabaseWorkerPool.h
@@ -45,6 +45,14 @@ class PingOperation : public SQLOperation
template <class T>
class DatabaseWorkerPool
{
+ private:
+ enum InternalIndex
+ {
+ IDX_ASYNC,
+ IDX_SYNCH,
+ IDX_SIZE
+ };
+
public:
/* Activity state */
DatabaseWorkerPool() : _connectionInfo(NULL)
@@ -74,34 +82,17 @@ class DatabaseWorkerPool
TC_LOG_INFO("sql.driver", "Opening DatabasePool '%s'. Asynchronous connections: %u, synchronous connections: %u.",
GetDatabaseName(), async_threads, synch_threads);
- //! Open asynchronous connections (delayed operations)
- _connections[IDX_ASYNC].resize(async_threads);
- for (uint8 i = 0; i < async_threads; ++i)
- {
- T* t = new T(_queue, *_connectionInfo);
- res &= t->Open();
- if (res) // only check mysql version if connection is valid
- WPFatal(mysql_get_server_version(t->GetHandle()) >= MIN_MYSQL_SERVER_VERSION, "TrinityCore does not support MySQL versions below 5.1");
- _connections[IDX_ASYNC][i] = t;
- ++_connectionCount[IDX_ASYNC];
- }
+ res = OpenConnections(IDX_ASYNC, async_threads);
- //! Open synchronous connections (direct, blocking operations)
- _connections[IDX_SYNCH].resize(synch_threads);
- for (uint8 i = 0; i < synch_threads; ++i)
- {
- T* t = new T(*_connectionInfo);
- res &= t->Open();
- _connections[IDX_SYNCH][i] = t;
- ++_connectionCount[IDX_SYNCH];
- }
+ if (!res)
+ return res;
+
+ res = OpenConnections(IDX_SYNCH, synch_threads);
if (res)
TC_LOG_INFO("sql.driver", "DatabasePool '%s' opened successfully. %u total connections running.", GetDatabaseName(),
(_connectionCount[IDX_SYNCH] + _connectionCount[IDX_ASYNC]));
- else
- TC_LOG_ERROR("sql.driver", "DatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile "
- "for specific errors. Read wiki at http://collab.kpsn.org/display/tc/TrinityCore+Home", GetDatabaseName());
+
return res;
}
@@ -112,8 +103,6 @@ class DatabaseWorkerPool
for (uint8 i = 0; i < _connectionCount[IDX_ASYNC]; ++i)
{
T* t = _connections[IDX_ASYNC][i];
- DatabaseWorker* worker = t->m_worker;
- delete worker;
t->Close(); //! Closes the actualy MySQL connection.
}
@@ -442,7 +431,7 @@ class DatabaseWorkerPool
if (str.empty())
return;
- char* buf = new char[str.size()*2+1];
+ char* buf = new char[str.size() * 2 + 1];
EscapeString(buf, str.c_str(), str.size());
str = buf;
delete[] buf;
@@ -470,6 +459,52 @@ class DatabaseWorkerPool
}
private:
+ bool OpenConnections(InternalIndex type, uint8 numConnections)
+ {
+ _connections[type].resize(numConnections);
+ for (uint8 i = 0; i < numConnections; ++i)
+ {
+ T* t;
+
+ if (type == IDX_ASYNC)
+ t = new T(_queue, *_connectionInfo);
+ else if (type == IDX_SYNCH)
+ t = new T(*_connectionInfo);
+
+ _connections[type][i] = t;
+ ++_connectionCount[type];
+
+ bool res = t->Open();
+
+ if (res)
+ {
+ if (mysql_get_server_version(t->GetHandle()) < MIN_MYSQL_SERVER_VERSION)
+ {
+ TC_LOG_ERROR("sql.driver", "TrinityCore does not support MySQL versions below 5.1");
+ res = false;
+ }
+ }
+
+ // Failed to open a connection or invalid version, abort and cleanup
+ if (!res)
+ {
+ TC_LOG_ERROR("sql.driver", "DatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile "
+ "for specific errors. Read wiki at http://collab.kpsn.org/display/tc/TrinityCore+Home", GetDatabaseName());
+
+ while (_connectionCount[type] != 0)
+ {
+ T* t = _connections[type][i--];
+ delete t;
+ --_connectionCount[type];
+ }
+
+ return false;
+ }
+ }
+
+ return true;
+ }
+
unsigned long EscapeString(char *to, const char *from, unsigned long length)
{
if (!to || !from || !length)
@@ -507,14 +542,6 @@ class DatabaseWorkerPool
return _connectionInfo->database.c_str();
}
- private:
- enum _internalIndex
- {
- IDX_ASYNC,
- IDX_SYNCH,
- IDX_SIZE
- };
-
ProducerConsumerQueue<SQLOperation*>* _queue; //! Queue shared by async worker threads.
std::vector< std::vector<T*> > _connections;
uint32 _connectionCount[2]; //! Counter of MySQL connections;
diff --git a/src/server/shared/Database/MySQLConnection.cpp b/src/server/shared/Database/MySQLConnection.cpp
index e9fc20aef82..4e46ff0e3a1 100644
--- a/src/server/shared/Database/MySQLConnection.cpp
+++ b/src/server/shared/Database/MySQLConnection.cpp
@@ -57,12 +57,14 @@ m_connectionFlags(CONNECTION_ASYNC)
MySQLConnection::~MySQLConnection()
{
- ASSERT (m_Mysql); /// MySQL context must be present at this point
-
for (size_t i = 0; i < m_stmts.size(); ++i)
delete m_stmts[i];
- mysql_close(m_Mysql);
+ if (m_Mysql)
+ mysql_close(m_Mysql);
+
+ if (m_worker)
+ delete m_worker;
}
void MySQLConnection::Close()
diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h
new file mode 100644
index 00000000000..fff94b86c1e
--- /dev/null
+++ b/src/server/shared/Networking/MessageBuffer.h
@@ -0,0 +1,93 @@
+/*
+* Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/>
+*
+* This program is free software; you can redistribute it and/or modify it
+* under the terms of the GNU General Public License as published by the
+* Free Software Foundation; either version 2 of the License, or (at your
+* option) any later version.
+*
+* This program is distributed in the hope that it will be useful, but WITHOUT
+* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+* more details.
+*
+* You should have received a copy of the GNU General Public License along
+* with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#ifndef __MESSAGEBUFFER_H_
+#define __MESSAGEBUFFER_H_
+
+#include "Define.h"
+#include <vector>
+
+class MessageBuffer
+{
+ typedef std::vector<uint8>::size_type size_type;
+
+public:
+ MessageBuffer() : _wpos(0), _storage() { }
+
+ MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _storage(right._storage) { }
+
+ MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _storage(right.Move()) { }
+
+ void Reset()
+ {
+ _storage.clear();
+ _wpos = 0;
+ }
+
+ bool IsMessageReady() const { return _wpos == _storage.size(); }
+
+ size_type GetMissingSize() const { return _storage.size() - _wpos; }
+
+ uint8* Data() { return _storage.data(); }
+
+ void Grow(size_type bytes)
+ {
+ _storage.resize(_storage.size() + bytes);
+ }
+
+ uint8* GetWritePointer() { return &_storage[_wpos]; }
+
+ void WriteCompleted(size_type bytes) { _wpos += bytes; }
+
+ void ResetWritePointer() { _wpos = 0; }
+
+ size_type GetSize() { return _storage.size(); }
+
+ std::vector<uint8>&& Move()
+ {
+ _wpos = 0;
+ return std::move(_storage);
+ }
+
+ MessageBuffer& operator=(MessageBuffer& right)
+ {
+ if (this != &right)
+ {
+ _wpos = right._wpos;
+ _storage = right._storage;
+ }
+
+ return *this;
+ }
+
+ MessageBuffer& operator=(MessageBuffer&& right)
+ {
+ if (this != &right)
+ {
+ _wpos = right._wpos;
+ _storage = right.Move();
+ }
+
+ return *this;
+ }
+
+private:
+ size_type _wpos;
+ std::vector<uint8> _storage;
+};
+
+#endif /* __MESSAGEBUFFER_H_ */
diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h
index 6bf67e06d9c..f86890ed7ef 100644
--- a/src/server/shared/Networking/Socket.h
+++ b/src/server/shared/Networking/Socket.h
@@ -18,8 +18,9 @@
#ifndef __SOCKET_H__
#define __SOCKET_H__
-#include "Define.h"
+#include "MessageBuffer.h"
#include "Log.h"
+#include <atomic>
#include <vector>
#include <mutex>
#include <queue>
@@ -28,19 +29,22 @@
#include <type_traits>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/write.hpp>
+#include <boost/asio/read.hpp>
using boost::asio::ip::tcp;
+#define READ_BLOCK_SIZE 4096
+
template<class T, class PacketType>
class Socket : public std::enable_shared_from_this<T>
{
typedef typename std::conditional<std::is_pointer<PacketType>::value, PacketType, PacketType const&>::type WritePacketType;
public:
- Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _headerSize(headerSize)
+ Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
+ _remotePort(_socket.remote_endpoint().port()), _readHeaderBuffer(), _readDataBuffer(), _closed(false)
{
- _remotePort = _socket.remote_endpoint().port();
- _remoteAddress = _socket.remote_endpoint().address();
+ _readHeaderBuffer.Grow(headerSize);
}
virtual void Start() = 0;
@@ -57,25 +61,39 @@ public:
void AsyncReadHeader()
{
- _socket.async_read_some(boost::asio::buffer(_readBuffer, _headerSize), std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(),
- std::placeholders::_1, std::placeholders::_2));
+ _readHeaderBuffer.ResetWritePointer();
+ _readDataBuffer.Reset();
+
+ AsyncReadMissingHeaderData();
}
- void AsyncReadData(std::size_t size, std::size_t bufferOffset)
+ void AsyncReadData(std::size_t size)
{
- _socket.async_read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(),
- std::placeholders::_1, std::placeholders::_2));
+ if (!size)
+ {
+ // if this is a packet with 0 length body just invoke handler directly
+ ReadDataHandler();
+ return;
+ }
+
+ _readDataBuffer.Grow(size);
+ AsyncReadMissingData();
}
- void ReadData(std::size_t size, std::size_t bufferOffset)
+ void ReadData(std::size_t size)
{
boost::system::error_code error;
- _socket.read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), error);
+ _readDataBuffer.Grow(size);
- if (error)
+ std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readDataBuffer.GetWritePointer(), size), error);
+
+ _readDataBuffer.WriteCompleted(bytesRead);
+
+ if (error || !_readDataBuffer.IsMessageReady())
{
- TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(), error.message().c_str());
+ TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(),
+ error.message().c_str());
CloseSocket();
}
@@ -83,18 +101,22 @@ public:
void AsyncWrite(WritePacketType data)
{
- boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(), std::placeholders::_1,
- std::placeholders::_2));
+ boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(),
+ std::placeholders::_1, std::placeholders::_2));
}
- bool IsOpen() const { return _socket.is_open(); }
+ bool IsOpen() const { return !_closed; }
+
void CloseSocket()
{
+ if (_closed.exchange(true))
+ return;
+
boost::system::error_code shutdownError;
_socket.shutdown(boost::asio::socket_base::shutdown_both, shutdownError);
if (shutdownError)
TC_LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(),
- shutdownError.value(), shutdownError.message().c_str());
+ shutdownError.value(), shutdownError.message().c_str());
boost::system::error_code error;
_socket.close(error);
@@ -103,18 +125,72 @@ public:
error.value(), error.message().c_str());
}
- uint8* GetReadBuffer() { return _readBuffer; }
+ virtual bool IsHeaderReady() const { return _readHeaderBuffer.IsMessageReady(); }
+ virtual bool IsDataReady() const { return _readDataBuffer.IsMessageReady(); }
+
+ uint8* GetHeaderBuffer() { return _readHeaderBuffer.Data(); }
+ uint8* GetDataBuffer() { return _readDataBuffer.Data(); }
+
+ MessageBuffer&& MoveHeader() { return std::move(_readHeaderBuffer); }
+ MessageBuffer&& MoveData() { return std::move(_readDataBuffer); }
protected:
- virtual void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) = 0;
- virtual void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) = 0;
+ virtual void ReadHeaderHandler() = 0;
+ virtual void ReadDataHandler() = 0;
std::mutex _writeLock;
std::queue<PacketType> _writeQueue;
private:
- void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadHeaderHandler(error, transferedBytes); }
- void ReadDataHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadDataHandler(error, transferedBytes); }
+ void AsyncReadMissingHeaderData()
+ {
+ _socket.async_read_some(boost::asio::buffer(_readHeaderBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readHeaderBuffer.GetMissingSize())),
+ std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
+ void AsyncReadMissingData()
+ {
+ _socket.async_read_some(boost::asio::buffer(_readDataBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readDataBuffer.GetMissingSize())),
+ std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
+ void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+ {
+ if (error)
+ {
+ CloseSocket();
+ return;
+ }
+
+ _readHeaderBuffer.WriteCompleted(transferredBytes);
+ if (!IsHeaderReady())
+ {
+ // incomplete, read more
+ AsyncReadMissingHeaderData();
+ return;
+ }
+
+ ReadHeaderHandler();
+ }
+
+ void ReadDataHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+ {
+ if (error)
+ {
+ CloseSocket();
+ return;
+ }
+
+ _readDataBuffer.WriteCompleted(transferredBytes);
+ if (!IsDataReady())
+ {
+ // incomplete, read more
+ AsyncReadMissingData();
+ return;
+ }
+
+ ReadDataHandler();
+ }
void WriteHandler(boost::system::error_code error, size_t /*transferedBytes*/)
{
@@ -140,12 +216,13 @@ private:
tcp::socket _socket;
- uint8 _readBuffer[4096];
-
- uint16 _remotePort;
boost::asio::ip::address _remoteAddress;
+ uint16 _remotePort;
+
+ MessageBuffer _readHeaderBuffer;
+ MessageBuffer _readDataBuffer;
- std::size_t _headerSize;
+ std::atomic<bool> _closed;
};
#endif // __SOCKET_H__
diff --git a/src/server/shared/Packets/ByteBuffer.cpp b/src/server/shared/Packets/ByteBuffer.cpp
index 86234039a4a..3785d1c29fa 100644
--- a/src/server/shared/Packets/ByteBuffer.cpp
+++ b/src/server/shared/Packets/ByteBuffer.cpp
@@ -17,11 +17,16 @@
*/
#include "ByteBuffer.h"
+#include "MessageBuffer.h"
#include "Common.h"
#include "Log.h"
#include <sstream>
+ByteBuffer::ByteBuffer(MessageBuffer&& buffer) : _rpos(0), _wpos(0), _storage(buffer.Move())
+{
+}
+
ByteBufferPositionException::ByteBufferPositionException(bool add, size_t pos,
size_t size, size_t valueSize)
{
diff --git a/src/server/shared/Packets/ByteBuffer.h b/src/server/shared/Packets/ByteBuffer.h
index c678e9dce06..456223d744d 100644
--- a/src/server/shared/Packets/ByteBuffer.h
+++ b/src/server/shared/Packets/ByteBuffer.h
@@ -34,6 +34,8 @@
#include <math.h>
#include <boost/asio/buffer.hpp>
+class MessageBuffer;
+
// Root of ByteBuffer exception hierarchy
class ByteBufferException : public std::exception
{
@@ -82,14 +84,12 @@ class ByteBuffer
}
ByteBuffer(ByteBuffer&& buf) : _rpos(buf._rpos), _wpos(buf._wpos),
- _storage(std::move(buf._storage))
- {
- }
+ _storage(std::move(buf._storage)) { }
ByteBuffer(ByteBuffer const& right) : _rpos(right._rpos), _wpos(right._wpos),
- _storage(right._storage)
- {
- }
+ _storage(right._storage) { }
+
+ ByteBuffer(MessageBuffer&& buffer);
ByteBuffer& operator=(ByteBuffer const& right)
{
diff --git a/src/server/shared/Packets/WorldPacket.h b/src/server/shared/Packets/WorldPacket.h
index 8851b9f3e45..848a00739fe 100644
--- a/src/server/shared/Packets/WorldPacket.h
+++ b/src/server/shared/Packets/WorldPacket.h
@@ -51,6 +51,8 @@ class WorldPacket : public ByteBuffer
return *this;
}
+ WorldPacket(uint16 opcode, MessageBuffer&& buffer) : ByteBuffer(std::move(buffer)), m_opcode(opcode) { }
+
void Initialize(uint16 opcode, size_t newres=200)
{
clear();
diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp
index 3afa9e84e8b..2c393215f7d 100644
--- a/src/server/worldserver/Main.cpp
+++ b/src/server/worldserver/Main.cpp
@@ -87,6 +87,7 @@ bool StartDB();
void StopDB();
void WorldUpdateLoop();
void ClearOnlineAccounts();
+void ShutdownThreadPool(std::vector<std::thread>& threadPool);
variables_map GetConsoleArguments(int argc, char** argv, std::string& cfg_file, std::string& cfg_service);
/// Launch the Trinity server
@@ -179,7 +180,10 @@ extern int main(int argc, char** argv)
// Start the databases
if (!StartDB())
+ {
+ ShutdownThreadPool(threadPool);
return 1;
+ }
// Set server offline (not connectable)
LoginDatabase.DirectPExecute("UPDATE realmlist SET flag = (flag & ~%u) | %u WHERE id = '%d'", REALM_FLAG_OFFLINE, REALM_FLAG_INVALID, realmID);
@@ -236,13 +240,7 @@ extern int main(int argc, char** argv)
WorldUpdateLoop();
// Shutdown starts here
-
- _ioService.stop();
-
- for (auto& thread : threadPool)
- {
- thread.join();
- }
+ ShutdownThreadPool(threadPool);
sScriptMgr->OnShutdown();
@@ -281,41 +279,7 @@ extern int main(int argc, char** argv)
if (cliThread != nullptr)
{
#ifdef _WIN32
-
- // this only way to terminate CLI thread exist at Win32 (alt. way exist only in Windows Vista API)
- //_exit(1);
- // send keyboard input to safely unblock the CLI thread
- INPUT_RECORD b[4];
- HANDLE hStdIn = GetStdHandle(STD_INPUT_HANDLE);
- b[0].EventType = KEY_EVENT;
- b[0].Event.KeyEvent.bKeyDown = TRUE;
- b[0].Event.KeyEvent.uChar.AsciiChar = 'X';
- b[0].Event.KeyEvent.wVirtualKeyCode = 'X';
- b[0].Event.KeyEvent.wRepeatCount = 1;
-
- b[1].EventType = KEY_EVENT;
- b[1].Event.KeyEvent.bKeyDown = FALSE;
- b[1].Event.KeyEvent.uChar.AsciiChar = 'X';
- b[1].Event.KeyEvent.wVirtualKeyCode = 'X';
- b[1].Event.KeyEvent.wRepeatCount = 1;
-
- b[2].EventType = KEY_EVENT;
- b[2].Event.KeyEvent.bKeyDown = TRUE;
- b[2].Event.KeyEvent.dwControlKeyState = 0;
- b[2].Event.KeyEvent.uChar.AsciiChar = '\r';
- b[2].Event.KeyEvent.wVirtualKeyCode = VK_RETURN;
- b[2].Event.KeyEvent.wRepeatCount = 1;
- b[2].Event.KeyEvent.wVirtualScanCode = 0x1c;
-
- b[3].EventType = KEY_EVENT;
- b[3].Event.KeyEvent.bKeyDown = FALSE;
- b[3].Event.KeyEvent.dwControlKeyState = 0;
- b[3].Event.KeyEvent.uChar.AsciiChar = '\r';
- b[3].Event.KeyEvent.wVirtualKeyCode = VK_RETURN;
- b[3].Event.KeyEvent.wVirtualScanCode = 0x1c;
- b[3].Event.KeyEvent.wRepeatCount = 1;
- DWORD numb;
- WriteConsoleInput(hStdIn, b, 4, &numb);
+ CancelSynchronousIo(cliThread->native_handle());
#endif
cliThread->join();
delete cliThread;
@@ -330,6 +294,15 @@ extern int main(int argc, char** argv)
return World::GetExitCode();
}
+void ShutdownThreadPool(std::vector<std::thread>& threadPool)
+{
+ _ioService.stop();
+
+ for (auto& thread : threadPool)
+ {
+ thread.join();
+ }
+}
void WorldUpdateLoop()
{
diff --git a/src/tools/mmaps_generator/PathGenerator.cpp b/src/tools/mmaps_generator/PathGenerator.cpp
index 3e2025dace8..c2ca184905e 100644
--- a/src/tools/mmaps_generator/PathGenerator.cpp
+++ b/src/tools/mmaps_generator/PathGenerator.cpp
@@ -276,7 +276,7 @@ int main(int argc, char** argv)
}
if (!checkDirectories(debugOutput))
- return silent ? -3 : finish("Press any key to close...", -3);
+ return silent ? -3 : finish("Press ENTER to close...", -3);
MapBuilder builder(maxAngle, skipLiquid, skipContinents, skipJunkMaps,
skipBattlegrounds, debugOutput, bigBaseUnit, offMeshInputPath);