From 6ce70465b3a140b6382d3ba25ca626703b67d0f3 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 5 Nov 2024 08:33:56 -0600 Subject: [PATCH] Add trace_tensors op We don't have an op that dispatches to the underlying iree.turbine.ops.iree.trace_tensor. This sets the default tracing callback to use the existing tracing functionality where we use safetensors. --- sharktank/pyproject.toml | 2 +- sharktank/requirements-tests.txt | 1 + sharktank/sharktank/ops/default_impls.py | 8 +++++ sharktank/sharktank/ops/signatures.py | 18 ++++++++++ sharktank/sharktank/utils/debugging.py | 22 +++++++++++- sharktank/tests/ops/ops_test.py | 44 ++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 2 deletions(-) diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml index 909977ea7..c29b04292 100644 --- a/sharktank/pyproject.toml +++ b/sharktank/pyproject.toml @@ -23,7 +23,7 @@ dependencies = ["iree-turbine"] dynamic = ["version"] # the version is set via the `setup.py` [project.optional-dependencies] -testing = ["pytest"] +testing = ["pytest", "safetensors"] [project.urls] Repository = "https://github.com/nod-ai/SHARK-Platform" diff --git a/sharktank/requirements-tests.txt b/sharktank/requirements-tests.txt index d5b4b0c0e..8371def85 100644 --- a/sharktank/requirements-tests.txt +++ b/sharktank/requirements-tests.txt @@ -2,3 +2,4 @@ datasets==3.0.0 parameterized pytest==8.0.0 pytest-html +safetensors==0.4.5 diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 92fe03a31..142eae6b7 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -25,6 +25,7 @@ from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType from .signatures import * import iree.turbine.ops.iree +from ..utils import debugging @cat.override(AllOfType(Tensor, PrimitiveTensor)) @@ -436,6 +437,13 @@ def to_default(tensor: Tensor, *args, **kwargs): return unbox_tensor(tensor).to(*args, 
**kwargs) +@trace_tensors.override(AllOfExprsVariadic(IsOfType(Tensor, InferenceTensor))) +def trace_tensors_default(key: str, *tensors: AnyTensor): + if len(tensors) != 1: + raise ValueError("Tracing more than one tensor at a time is not supported.") + iree.turbine.ops.iree.trace_tensor(key, unshard(tensors[0])) + + @transfer_to_logical_device.override(Tensor) def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): return iree.turbine.ops.iree.transfer_to_logical_device( diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index d9002ce37..1eae6fdab 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -53,6 +53,7 @@ "sharded_sum", "softmax", "to", + "trace_tensors", "transfer_to_logical_device", "transpose", "unflatten", @@ -977,6 +978,23 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs): d.fail(dispatch_args) +@overridable +def trace_tensors(key: str, *tensors: AnyTensor): + ... 
+ + +@trace_tensors.trampoline +def _trace_tensors_trampoline( + d: SignatureDispatcher, key: str, *tensors: AnyTensor +): + for override in d.find_overrides(tensors): + result = override(key, *tensors) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor: """Transfer the tensor to a device with ordinal `ordinal`.""" diff --git a/sharktank/sharktank/utils/debugging.py b/sharktank/sharktank/utils/debugging.py index 1371732e2..3e19b4056 100644 --- a/sharktank/sharktank/utils/debugging.py +++ b/sharktank/sharktank/utils/debugging.py @@ -5,13 +5,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """Tools for debugging models.""" -from typing import Dict, Optional +from typing import Callable, Dict, Optional, Tuple from dataclasses import dataclass import re import os from pathlib import Path from typing import Sequence +import iree.turbine.support.debugging import torch @@ -82,6 +83,25 @@ def parse_from_env() -> "DebugFlags": flags = DebugFlags.parse_from_env() +TraceKey = str +TraceTensors = Callable[[TraceKey, *Tuple[torch.Tensor, ...]], None] + + +def set_trace_tensors_callback(callback: TraceTensors): + iree.turbine.support.debugging.trace_tensor_callback = callback + + +def get_trace_tensors_callback() -> Optional[TraceTensors]: + return iree.turbine.support.debugging.trace_tensor_callback + + +def default_trace_tensors_callback(key: str, *tensors: torch.Tensor): + tensors_in_dict = {f"{i}": t for i, t in enumerate(tensors)} + trace_tensors(key, tensors_in_dict, values=False, golden=True) + + +set_trace_tensors_callback(default_trace_tensors_callback) + def trace_tensor( key: str, t: torch.Tensor, *, values: bool = True, golden: bool = False diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py index cd55b23d0..5299a08e2 100644 --- a/sharktank/tests/ops/ops_test.py +++ 
b/sharktank/tests/ops/ops_test.py @@ -12,6 +12,9 @@ from sharktank import ops from sharktank.types import * +from sharktank.utils import debugging +from sharktank.utils.testing import TempDirTestBase +import safetensors class BroadcastDimsTest(unittest.TestCase): @@ -291,5 +294,46 @@ def forward(self, a, d, qs, m): self.assertIn("mmt_block_scaled_offset_q4_unsigned.default", s) +class TestTraceTensors(TempDirTestBase): + def testTraceOneTensorInEagerMode(self): + save_goldens_path_stash = debugging.flags.save_goldens_path + debugging.flags.save_goldens_path = self._temp_dir + debugging.flags.golden_sequence_value = 0 + + tensor = torch.arange(1, 5) + trace_key = "test_trace_key" + ops.trace_tensors(trace_key, tensor) + + trace_filepath = ( + debugging.flags.save_goldens_path / f"0000_{trace_key}.safetensors" + ) + with safetensors.safe_open(trace_filepath, framework="pt", device="cpu") as f: + assert len(f.keys()) == 1 + recorded_tensor = f.get_tensor("0") + torch.testing.assert_close(recorded_tensor, tensor, rtol=0, atol=0) + + debugging.flags.save_goldens_path = save_goldens_path_stash + + def testTraceOneShardedTensorInEagerMode(self): + save_goldens_path_stash = debugging.flags.save_goldens_path + debugging.flags.save_goldens_path = self._temp_dir + debugging.flags.golden_sequence_value = 0 + + tensor = torch.arange(1, 6) + sharded_tensor = ops.reshard_split(tensor, count=2, dim=0) + trace_key = "test_trace_key" + ops.trace_tensors(trace_key, sharded_tensor) + + trace_filepath = ( + debugging.flags.save_goldens_path / f"0000_{trace_key}.safetensors" + ) + with safetensors.safe_open(trace_filepath, framework="pt", device="cpu") as f: + assert len(f.keys()) == 1 + recorded_tensor = f.get_tensor("0") + torch.testing.assert_close(recorded_tensor, tensor, rtol=0, atol=0) + + debugging.flags.save_goldens_path = save_goldens_path_stash + + if __name__ == "__main__": unittest.main()