Add trace_tensors op
We did not have an op that dispatches to the underlying
iree.turbine.ops.iree.trace_tensor; this commit adds one.

It also points the default tracing callback at the existing
safetensors-based tracing functionality in sharktank.utils.debugging.
sogartar committed Nov 5, 2024
1 parent 1c6800b commit 6ce7046
Showing 6 changed files with 93 additions and 2 deletions.
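
A minimal eager-mode sketch of the new op, assuming the flag names from sharktank.utils.debugging as changed in this commit; the output directory is a hypothetical choice:

from pathlib import Path

import torch

from sharktank import ops
from sharktank.utils import debugging

# Hypothetical output directory for golden traces (assumed to exist).
debugging.flags.save_goldens_path = Path("/tmp/goldens")
debugging.flags.golden_sequence_value = 0

# Dispatches to the eager override below, which forwards to
# iree.turbine.ops.iree.trace_tensor via the default safetensors
# callback; per the tests, the trace lands in
# /tmp/goldens/0000_my_key.safetensors.
ops.trace_tensors("my_key", torch.arange(1, 5))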
2 changes: 1 addition & 1 deletion sharktank/pyproject.toml
@@ -23,7 +23,7 @@ dependencies = ["iree-turbine"]
dynamic = ["version"] # the version is set via the `setup.py`

[project.optional-dependencies]
testing = ["pytest"]
testing = ["pytest", "safetensors"]

[project.urls]
Repository = "https://github.com/nod-ai/SHARK-Platform"
1 change: 1 addition & 0 deletions sharktank/requirements-tests.txt
@@ -2,3 +2,4 @@ datasets==3.0.0
parameterized
pytest==8.0.0
pytest-html
safetensors==0.4.5
8 changes: 8 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -25,6 +25,7 @@
from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
from .signatures import *
import iree.turbine.ops.iree
from ..utils import debugging


@cat.override(AllOfType(Tensor, PrimitiveTensor))
@@ -436,6 +437,13 @@ def to_default(tensor: Tensor, *args, **kwargs):
    return unbox_tensor(tensor).to(*args, **kwargs)


@trace_tensors.override(AllOfExprsVariadic(IsOfType(Tensor, InferenceTensor)))
def trace_tensors_default(key: str, *tensors: AnyTensor):
    if len(tensors) != 1:
        raise ValueError("Tracing more than one tensor at a time is not supported.")
    iree.turbine.ops.iree.trace_tensor(key, unshard(tensors[0]))


@transfer_to_logical_device.override(Tensor)
def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
    return iree.turbine.ops.iree.transfer_to_logical_device(
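A small sanity sketch of the single-tensor restriction enforced by the override above (names from this diff; eager mode):

import torch

from sharktank import ops

# The override accepts a variadic tensor list but traces exactly one
# tensor per call; passing more than one raises ValueError.
try:
    ops.trace_tensors("two_at_once", torch.ones(2), torch.ones(3))
except ValueError as e:
    print(e)  # Tracing more than one tensor at a time is not supported.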
18 changes: 18 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -53,6 +53,7 @@
"sharded_sum",
"softmax",
"to",
"trace_tensors",
"transfer_to_logical_device",
"transpose",
"unflatten",
@@ -977,6 +978,23 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):
    d.fail(dispatch_args)


@overridable
def trace_tensors(key: str, *tensors: AnyTensor):
    ...


@trace_tensors.trampoline
def _trace_tensors_trampoline(
    d: SignatureDispatcher, key: str, *tensors: AnyTensor
):
    for override in d.find_overrides(tensors):
        result = override(key, *tensors)
        if result is not NotImplemented:
            return override, result
    else:
        d.fail(tensors)


@overridable
def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
"""Transfer the tensor to a device with ordinal `ordinal`."""
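For context, the trampoline above follows the registry's usual pattern: try each override that matches the argument types and return the first result that is not NotImplemented, failing if none accepts. A simplified, self-contained sketch (a toy stand-in, not the real SignatureDispatcher API):

class MiniDispatcher:
    """Toy stand-in for SignatureDispatcher."""

    def __init__(self):
        self._overrides = []

    def override(self, fn):
        self._overrides.append(fn)
        return fn

    def find_overrides(self, args):
        # The real dispatcher filters by type predicates; we return all.
        return list(self._overrides)

    def fail(self, args):
        raise NotImplementedError(f"No override accepted {args!r}")


d = MiniDispatcher()


@d.override
def trace_ints(key, *xs):
    if not all(isinstance(x, int) for x in xs):
        return NotImplemented
    return f"{key}: {xs}"


def trampoline(key, *xs):
    for override in d.find_overrides(xs):
        result = override(key, *xs)
        if result is not NotImplemented:
            return override, result
    else:
        # for/else: runs only if no override returned a result.
        d.fail(xs)


print(trampoline("ints", 1, 2, 3))  # (<function trace_ints ...>, 'ints: (1, 2, 3)')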
22 changes: 21 additions & 1 deletion sharktank/sharktank/utils/debugging.py
@@ -5,13 +5,14 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Tools for debugging models."""
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Tuple

from dataclasses import dataclass
import re
import os
from pathlib import Path
from typing import Sequence
import iree.turbine.support.debugging

import torch

Expand Down Expand Up @@ -82,6 +83,25 @@ def parse_from_env() -> "DebugFlags":

flags = DebugFlags.parse_from_env()

TraceKey = str
TraceTensors = Callable[[TraceKey, *Tuple[torch.Tensor, ...]], None]


def set_trace_tensors_callback(callback: TraceTensors):
    iree.turbine.support.debugging.trace_tensor_callback = callback


def get_trace_tensors_callback() -> Optional[TraceTensors]:
    return iree.turbine.support.debugging.trace_tensor_callback


def default_trace_tensors_callback(key: str, *tensors: torch.Tensor):
    tensors_in_dict = {f"{i}": t for i, t in enumerate(tensors)}
    trace_tensors(key, tensors_in_dict, values=False, golden=True)


set_trace_tensors_callback(default_trace_tensors_callback)


def trace_tensor(
    key: str, t: torch.Tensor, *, values: bool = True, golden: bool = False
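The new hooks also allow swapping in a custom callback. A hedged sketch using the getter/setter from this diff; my_logging_callback is hypothetical, not part of the commit:

import torch

from sharktank.utils import debugging


def my_logging_callback(key: str, *tensors: torch.Tensor):
    # Hypothetical replacement: print shapes instead of saving safetensors.
    for i, t in enumerate(tensors):
        print(f"trace {key}[{i}]: shape={tuple(t.shape)} dtype={t.dtype}")


previous = debugging.get_trace_tensors_callback()
debugging.set_trace_tensors_callback(my_logging_callback)
# ... run traced code here ...
debugging.set_trace_tensors_callback(previous)  # restore the default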
44 changes: 44 additions & 0 deletions sharktank/tests/ops/ops_test.py
@@ -12,6 +12,9 @@

from sharktank import ops
from sharktank.types import *
from sharktank.utils import debugging
from sharktank.utils.testing import TempDirTestBase
import safetensors


class BroadcastDimsTest(unittest.TestCase):
@@ -291,5 +294,46 @@ def forward(self, a, d, qs, m):
        self.assertIn("mmt_block_scaled_offset_q4_unsigned.default", s)


class TestTraceTensors(TempDirTestBase):
    def testTraceOneTensorInEagerMode(self):
        save_goldens_path_stash = debugging.flags.save_goldens_path
        debugging.flags.save_goldens_path = self._temp_dir
        debugging.flags.golden_sequence_value = 0

        tensor = torch.arange(1, 5)
        trace_key = "test_trace_key"
        ops.trace_tensors(trace_key, tensor)

        trace_filepath = (
            debugging.flags.save_goldens_path / f"0000_{trace_key}.safetensors"
        )
        with safetensors.safe_open(trace_filepath, framework="pt", device="cpu") as f:
            assert len(f.keys()) == 1
            recorded_tensor = f.get_tensor("0")
            torch.testing.assert_close(recorded_tensor, tensor, rtol=0, atol=0)

        debugging.flags.save_goldens_path = save_goldens_path_stash

    def testTraceOneShardedTensorInEagerMode(self):
        save_goldens_path_stash = debugging.flags.save_goldens_path
        debugging.flags.save_goldens_path = self._temp_dir
        debugging.flags.golden_sequence_value = 0

        tensor = torch.arange(1, 6)
        sharded_tensor = ops.reshard_split(tensor, count=2, dim=0)
        trace_key = "test_trace_key"
        ops.trace_tensors(trace_key, sharded_tensor)

        trace_filepath = (
            debugging.flags.save_goldens_path / f"0000_{trace_key}.safetensors"
        )
        with safetensors.safe_open(trace_filepath, framework="pt", device="cpu") as f:
            assert len(f.keys()) == 1
            recorded_tensor = f.get_tensor("0")
            torch.testing.assert_close(recorded_tensor, tensor, rtol=0, atol=0)

        debugging.flags.save_goldens_path = save_goldens_path_stash


if __name__ == "__main__":
    unittest.main()
