diff options
Diffstat (limited to 'src/server/bnetserver/Server/Session.cpp')
-rw-r--r-- | src/server/bnetserver/Server/Session.cpp | 103 |
1 files changed, 33 insertions, 70 deletions
diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp index c3a8ae6f211..7b9d6f45a3e 100644 --- a/src/server/bnetserver/Server/Session.cpp +++ b/src/server/bnetserver/Server/Session.cpp @@ -23,6 +23,7 @@ #include "Errors.h" #include "Hash.h" #include "IPLocation.h" +#include "IpBanCheckConnectionInitializer.h" #include "LoginRESTService.h" #include "MapUtils.h" #include "ProtobufJSON.h" @@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields) DisplayName = Name; } -Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()), +Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()), _accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(), _os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0) { @@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo Battlenet::Session::~Session() = default; -void Battlenet::Session::AsyncHandshake() -{ - underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, - [sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); }); -} - void Battlenet::Session::Start() { - std::string ip_address = GetRemoteIpAddress().to_string(); TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo()); - // Verify that this IP is not in the ip_banned table - LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); + // build initializer chain + std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this), + std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this), + std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this), + } }; - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); - - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) - .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); -} - -void Battlenet::Session::CheckIpCallback(PreparedQueryResult result) -{ - if (result) - { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; - - } while (result->NextRow()); - - if (banned) - { - TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo()); - CloseSocket(); - return; - } - } - - AsyncHandshake(); + Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start(); } bool Battlenet::Session::Update() { - if (!BattlenetSocket::Update()) + if (!BaseSocket::Update()) return false; _queryProcessor.ProcessReadyCallbacks(); @@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me AsyncWrite(&packet); } +void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback) +{ + _queryProcessor.AddCallback(std::move(queryCallback)); +} + uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation) { if (logonRequest->program() != "WoW") @@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID); stmt->setUInt32(0, _accountInfo->Id); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result) + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result) { // just send existing credentials back (not the best but it works for now with them being stored in db) Battlenet::Services::Authentication asyncContinuationService(this); @@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation); std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>(); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result) + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result) { Battlenet::Services::Authentication asyncContinuationService(this); NoData response; @@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge return ERROR_RPC_NOT_IMPLEMENTED; } -void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error) -{ - if (error) - { - TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message()); - CloseSocket(); - return; - } - - AsyncRead(); -} - template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer> -inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer) +static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer) { MessageBuffer& buffer = session->*outputBuffer; @@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp buffer.Write(inputBuffer.GetReadPointer(), readDataSize); inputBuffer.ReadCompleted(readDataSize); } + else + return { }; // go to next buffer if (buffer.GetRemainingSpace() > 0) { // Couldn't receive the whole data this time. ASSERT(inputBuffer.GetActiveSize() == 0); - return false; + return Trinity::Net::SocketReadCallbackResult::KeepReading; } // just received fresh new payload if (!(session->*processMethod)()) { session->CloseSocket(); - return false; + return Trinity::Net::SocketReadCallbackResult::Stop; } - return true; + return { }; // go to next buffer } -void Battlenet::Session::ReadHandler() +Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler() { - if (!IsOpen()) - return; - MessageBuffer& packet = GetReadBuffer(); while (packet.GetActiveSize() > 0) { - if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet)) + return *partialResult; - if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet)) + return *partialResult; - if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet)) + return *partialResult; _headerLengthBuffer.Reset(); _headerBuffer.Reset(); } - AsyncRead(); + return Trinity::Net::SocketReadCallbackResult::KeepReading; } bool Battlenet::Session::ReadHeaderLengthHandler() @@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const stream << ']'; - return stream.str(); + return std::move(stream).str(); } |