Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1 Core/Spells: Protect against stack overflows in spell override handling #322

Merged
merged 1 commit into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/server/game/Entities/Player/Player.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30035,15 +30035,16 @@ Difficulty Player::CheckLoadedLegacyRaidDifficultyID(Difficulty difficulty)
return difficulty;
}

SpellInfo const* Player::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const
SpellInfo const* Player::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const
{
auto overrides = m_overrideSpells.find(spellInfo->Id);
if (overrides != m_overrideSpells.end())
for (uint32 spellId : overrides->second)
if (SpellInfo const* newInfo = sSpellMgr->GetSpellInfo(spellId, GetMap()->GetDifficultyID()))
return GetCastSpellInfo(newInfo, triggerFlag);
if (context->AddSpell(spellId))
if (SpellInfo const* newInfo = sSpellMgr->GetSpellInfo(spellId, GetMap()->GetDifficultyID()))
return GetCastSpellInfo(newInfo, triggerFlag, context);

return Unit::GetCastSpellInfo(spellInfo, triggerFlag);
return Unit::GetCastSpellInfo(spellInfo, triggerFlag, context);
}

void Player::AddOverrideSpell(uint32 overridenSpellId, uint32 newSpellId)
Expand Down Expand Up @@ -30671,7 +30672,8 @@ void Player::ExecutePendingSpellCastRequest()
}

// Check possible spell cast overrides
spellInfo = castingUnit->GetCastSpellInfo(spellInfo, triggerFlag);
GetCastSpellInfoContext overrideContext;
spellInfo = castingUnit->GetCastSpellInfo(spellInfo, triggerFlag, &overrideContext);
if (spellInfo->IsPassive())
{
CancelPendingCastRequest();
Expand Down
2 changes: 1 addition & 1 deletion src/server/game/Entities/Player/Player.h
Original file line number Diff line number Diff line change
Expand Up @@ -1872,7 +1872,7 @@ class TC_GAME_API Player final : public Unit, public GridObject<Player>
void SendRemoveControlBar() const;
bool HasSpell(uint32 spell) const override;
bool HasActiveSpell(uint32 spell) const; // show in spellbook
SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const override;
SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const override;
bool IsSpellFitByClassAndRace(uint32 spell_id) const;
bool HandlePassiveSpellLearn(SpellInfo const* spellInfo);

Expand Down
24 changes: 19 additions & 5 deletions src/server/game/Entities/Unit/Unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13889,14 +13889,28 @@ void Unit::ClearBossEmotes(Optional<uint32> zoneId, Player const* target) const
ref.GetSource()->SendDirectMessage(clearBossEmotes.GetRawPacket());
}

SpellInfo const* Unit::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const
bool Unit::GetCastSpellInfoContext::AddSpell(uint32 spellId)
{
auto findMatchingAuraEffectIn = [this, spellInfo, &triggerFlag](AuraType type) -> SpellInfo const*
auto itr = std::ranges::find(VisitedSpells, spellId);
if (itr != VisitedSpells.end())
return false; // already exists

itr = std::ranges::find(VisitedSpells, 0u);
if (itr == VisitedSpells.end())
return false; // no free slots left

*itr = spellId;
return true;
}

SpellInfo const* Unit::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const
{
auto findMatchingAuraEffectIn = [this, spellInfo, &triggerFlag, context](AuraType type) -> SpellInfo const*
{
for (AuraEffect const* auraEffect : GetAuraEffectsByType(type))
{
bool matches = auraEffect->GetMiscValue() ? uint32(auraEffect->GetMiscValue()) == spellInfo->Id : auraEffect->IsAffectingSpell(spellInfo);
if (matches)
if (matches && context->AddSpell(auraEffect->GetAmount()))
{
if (SpellInfo const* newInfo = sSpellMgr->GetSpellInfo(auraEffect->GetAmount(), GetMap()->GetDifficultyID()))
{
Expand All @@ -13921,13 +13935,13 @@ SpellInfo const* Unit::GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastF
if (SpellInfo const* newInfo = findMatchingAuraEffectIn(SPELL_AURA_OVERRIDE_ACTIONBAR_SPELLS))
{
triggerFlag &= ~TRIGGERED_IGNORE_CAST_TIME;
return GetCastSpellInfo(newInfo, triggerFlag);
return GetCastSpellInfo(newInfo, triggerFlag, context);
}

if (SpellInfo const* newInfo = findMatchingAuraEffectIn(SPELL_AURA_OVERRIDE_ACTIONBAR_SPELLS_TRIGGERED))
{
triggerFlag |= TRIGGERED_IGNORE_CAST_TIME;
return GetCastSpellInfo(newInfo, triggerFlag);
return GetCastSpellInfo(newInfo, triggerFlag, context);
}

return spellInfo;
Expand Down
7 changes: 6 additions & 1 deletion src/server/game/Entities/Unit/Unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -1452,7 +1452,12 @@ class TC_GAME_API Unit : public WorldObject
Spell* GetCurrentSpell(uint32 spellType) const { return m_currentSpells[spellType]; }
Spell* FindCurrentSpellBySpellId(uint32 spell_id) const;
int32 GetCurrentSpellCastTime(uint32 spell_id) const;
virtual SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag) const;
struct GetCastSpellInfoContext
{
std::array<uint32, 5> VisitedSpells = { };
bool AddSpell(uint32 spellId);
};
virtual SpellInfo const* GetCastSpellInfo(SpellInfo const* spellInfo, TriggerCastFlags& triggerFlag, GetCastSpellInfoContext* context) const;
uint32 GetCastSpellXSpellVisualId(SpellInfo const* spellInfo) const override;

virtual bool HasSpellFocus(Spell const* /*focusSpell*/ = nullptr) const { return false; }
Expand Down
Loading