🐛 [HotFix] Handle Profiler Activities Based on PyTorch Version (#3136)
yhna940 authored Oct 7, 2024
1 parent e93b056 commit cd93e35
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .constants import (
+    MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
     MODEL_NAME,
     OPTIMIZER_NAME,
     PROFILE_PATTERN_NAME,
@@ -28,6 +29,7 @@
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     WEIGHTS_PATTERN_NAME,
+    XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
 )
 from .dataclasses import (
     AutocastKwargs,
2 changes: 2 additions & 0 deletions src/accelerate/utils/constants.py
@@ -44,6 +44,8 @@
 DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"]
 TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
 ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
+XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
+MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
 
 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
 
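For intuition: the STR_OPERATION_TO_FUNC table visible in the context lines above maps comparison strings to `operator` functions, and it is what powers version checks such as `is_torch_version(">=", ...)`. A minimal sketch of such a helper, assuming only `torch` and `packaging` are installed (accelerate's real implementation lives in `accelerate/utils/versions.py` and may differ in detail):

import operator as op

import torch
from packaging.version import parse

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}


def is_torch_version(operation: str, version: str) -> bool:
    # Compare the installed torch version (local build suffix stripped)
    # against `version` using the requested comparison operator.
    torch_version = parse(torch.__version__.split("+")[0])
    return STR_OPERATION_TO_FUNC[operation](torch_version, parse(version))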
18 changes: 14 additions & 4 deletions src/accelerate/utils/dataclasses.py
@@ -29,7 +29,13 @@
 
 import torch
 
-from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY
+from .constants import (
+    FSDP_AUTO_WRAP_POLICY,
+    FSDP_BACKWARD_PREFETCH,
+    FSDP_SHARDING_STRATEGY,
+    MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
+    XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
+)
 from .environment import parse_flag_from_env, str_to_bool
 from .imports import (
     is_cuda_available,
@@ -39,7 +45,7 @@
     is_transformer_engine_available,
     is_xpu_available,
 )
-from .versions import compare_versions
+from .versions import compare_versions, is_torch_version
 
 
 class KwargsHandler:
@@ -468,11 +474,15 @@ def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity:
 
         profiler_activity_map: dict[str, torch.profiler.ProfilerActivity] = {
             "cpu": torch.profiler.ProfilerActivity.CPU,
-            "xpu": torch.profiler.ProfilerActivity.XPU,
-            "mita": torch.profiler.ProfilerActivity.MTIA,
             "cuda": torch.profiler.ProfilerActivity.CUDA,
         }
 
+        if is_torch_version(">=", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION):
+            profiler_activity_map["xpu"] = torch.profiler.ProfilerActivity.XPU
+
+        if is_torch_version(">=", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION):
+            profiler_activity_map["mtia"] = torch.profiler.ProfilerActivity.MTIA
+
         if activity not in profiler_activity_map:
             raise ValueError(f"Invalid profiler activity: {activity}. Must be one of {list(profiler_activity_map)}.")
         return profiler_activity_map[activity]
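Before this fix, the map referenced ProfilerActivity.XPU and ProfilerActivity.MTIA unconditionally (and under the misspelled key "mita", corrected here to "mtia"), so building the dict raised AttributeError on PyTorch builds that predate those enum members, even when the caller only profiled CPU or CUDA. A small illustration of the failure mode the gating avoids, safe to run on any PyTorch build:

import torch

# On torch < 2.4.0 this attribute does not exist, so merely referencing it
# raises AttributeError; that is why the map entry is now version-gated.
try:
    xpu_activity = torch.profiler.ProfilerActivity.XPU
    print("XPU profiling activity available:", xpu_activity)
except AttributeError:
    print("This torch build predates ProfilerActivity.XPU")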

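Downstream, these activity keys surface through accelerate's ProfileKwargs handler, whose `activities` field is validated by `_get_profiler_activity`. A hedged usage sketch (API names as documented for accelerate's profiler utilities; with this fix, "xpu"/"mtia" on an older PyTorch now fail with a clear ValueError instead of an AttributeError at map construction):

from accelerate import Accelerator
from accelerate.utils import ProfileKwargs

# "cpu" and "cuda" are always registered; "xpu" needs torch >= 2.4.0 and
# "mtia" needs torch >= 2.1.0, otherwise _get_profiler_activity raises ValueError.
profile_kwargs = ProfileKwargs(activities=["cpu", "cuda"])
accelerator = Accelerator(kwargs_handlers=[profile_kwargs])

with accelerator.profile() as prof:
    pass  # the training or inference steps to profile go here

print(prof.key_averages().table(sort_by="cpu_time_total"))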