diff options
author | Shauren <shauren.trinity@gmail.com> | 2024-09-22 13:17:08 +0200 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2024-09-22 13:17:08 +0200 |
commit | 723e638a84d2aebe8a41fc39ba94f1a04138797e (patch) | |
tree | 691e6c1614462e5fcce4b05d674cba38edd3f748 /src | |
parent | 7eb3189a8046d35de997a24feabab7bf2f156fe0 (diff) |
Core/Spells: Protect against stack overflows in spell override handling
Diffstat (limited to 'src')
-rw-r--r-- | src/server/game/Entities/Player/Player.cpp | 12 | ||||
-rw-r--r-- | src/server/game/Entities/Player/Player.h | 2 | ||||
-rw-r--r-- | src/server/game/Entities/Unit/Unit.cpp | 24 | ||||
-rw-r--r-- | src/server/game/Entities/Unit/Unit.h | 7 |
4 files changed, 33 insertions, 12 deletions
diff --git a/src/server/game/Entities/Player/Player.cpp b/src/server/game/Entities/Player/Player.cpp index ff98f642fc0..46e357e9b60 100644 --- a/src/server/game/Entities/Player/Player.cpp +++ b/src/server/game/Entities/Player/Player.cpp @@ -30035,15 +30035,16 @@ Difficulty Player::CheckLoadedLegacyRaidDifficultyID(Difficulty difficulty) return difficulty; } -SpellInfo const* Player::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const +SpellInfo const* Player::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const { auto overrides = m_overrideSpells.find(spellInfo->Id); if (overrides != m_overrideSpells.end()) for (uint32 spellId : overrides->second) - if (SpellInfo const* newInfo = sSpellMgr->GetSpellInfo(spellId, GetMap()->GetDifficultyID())) - return GetCastSpellInfo(newInfo, triggerFlag); + if (context->AddSpell(spellId)) + if (SpellInfo const* newInfo = sSpellMgr->GetSpellInfo(spellId, GetMap()->GetDifficultyID())) + return GetCastSpellInfo(newInfo, triggerFlag, context); - return Unit::GetCastSpellInfo(spellInfo, triggerFlag); + return Unit::GetCastSpellInfo(spellInfo, triggerFlag, context); } void Player::AddOverrideSpell(uint32 overridenSpellId, uint32 newSpellId) @@ -30671,7 +30672,8 @@ void Player::ExecutePendingSpellCastRequest() } // Check possible spell cast overrides - spellInfo = castingUnit->GetCastSpellInfo(spellInfo, triggerFlag); + GetCastSpellInfoContext overrideContext; + spellInfo = castingUnit->GetCastSpellInfo(spellInfo, triggerFlag, &overrideContext); if (spellInfo->IsPassive()) { CancelPendingCastRequest(); diff --git a/src/server/game/Entities/Player/Player.h b/src/server/game/Entities/Player/Player.h index 5dbc2e8c305..957bf84d891 100644 --- a/src/server/game/Entities/Player/Player.h +++ b/src/server/game/Entities/Player/Player.h @@ -1872,7 +1872,7 @@ class TC_GAME_API Player final : public Unit, public GridObject<Player> void SendRemoveControlBar() const; bool HasSpell(uint32 spell) const override; bool HasActiveSpell(uint32 spell) const; // show in spellbook - SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const override; + SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const override; bool IsSpellFitByClassAndRace(uint32 spell_id) const; bool HandlePassiveSpellLearn(SpellInfo const* spellInfo); diff --git a/src/server/game/Entities/Unit/Unit.cpp b/src/server/game/Entities/Unit/Unit.cpp index 09fd65098c0..b709f973db2 100644 --- a/src/server/game/Entities/Unit/Unit.cpp +++ b/src/server/game/Entities/Unit/Unit.cpp @@ -13889,14 +13889,28 @@ void Unit::ClearBossEmotes(Optional<uint32> zoneId, Player const* target) const ref.GetSource()->SendDirectMessage(clearBossEmotes.GetRawPacket()); } -SpellInfo const* Unit::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const +bool Unit::GetCastSpellInfoContext::AddSpell(uint32 spellId) { - auto findMatchingAuraEffectIn = [this, spellInfo, &triggerFlag](AuraType type) -> SpellInfo const* + auto itr = std::ranges::find(VisitedSpells, spellId); + if (itr != VisitedSpells.end()) + return false; // already exists + + itr = std::ranges::find(VisitedSpells, 0u); + if (itr == VisitedSpells.end()) + return false; // no free slots left + + *itr = spellId; + return true; +} + +SpellInfo const* Unit::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const +{ + auto findMatchingAuraEffectIn = [this, spellInfo, &triggerFlag, context](AuraType type) -> SpellInfo const* { for (AuraEffect const* auraEffect : GetAuraEffectsByType(type)) { bool matches = auraEffect->GetMiscValue() ? uint32(auraEffect->GetMiscValue()) == spellInfo->Id : auraEffect->IsAffectingSpell(spellInfo); - if (matches) + if (matches && context->AddSpell(auraEffect->GetAmount())) { if (SpellInfo const* newInfo = sSpellMgr->GetSpellInfo(auraEffect->GetAmount(), GetMap()->GetDifficultyID())) { @@ -13921,13 +13935,13 @@ SpellInfo const* Unit::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastF if (SpellInfo const* newInfo = findMatchingAuraEffectIn(SPELL_AURA_OVERRIDE_ACTIONBAR_SPELLS)) { triggerFlag &= ~TRIGGERED_IGNORE_CAST_TIME; - return GetCastSpellInfo(newInfo, triggerFlag); + return GetCastSpellInfo(newInfo, triggerFlag, context); } if (SpellInfo const* newInfo = findMatchingAuraEffectIn(SPELL_AURA_OVERRIDE_ACTIONBAR_SPELLS_TRIGGERED)) { triggerFlag |= TRIGGERED_IGNORE_CAST_TIME; - return GetCastSpellInfo(newInfo, triggerFlag); + return GetCastSpellInfo(newInfo, triggerFlag, context); } return spellInfo; diff --git a/src/server/game/Entities/Unit/Unit.h b/src/server/game/Entities/Unit/Unit.h index b60317dd891..47dbe3498bb 100644 --- a/src/server/game/Entities/Unit/Unit.h +++ b/src/server/game/Entities/Unit/Unit.h @@ -1452,7 +1452,12 @@ class TC_GAME_API Unit : public WorldObject Spell* GetCurrentSpell(uint32 spellType) const { return m_currentSpells[spellType]; } Spell* FindCurrentSpellBySpellId(uint32 spell_id) const; int32 GetCurrentSpellCastTime(uint32 spell_id) const; - virtual SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const; + struct GetCastSpellInfoContext + { + std::array<uint32, 5> VisitedSpells = { }; + bool AddSpell(uint32 spellId); + }; + virtual SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const; uint32 GetCastSpellXSpellVisualId(SpellInfo const* spellInfo) const override; virtual bool HasSpellFocus(Spell const* /*focusSpell*/ = nullptr) const { return false; } |