Tensor tracker

API documentation | Example

Flexibly track outputs and grad-outputs of torch.nn.Module.

Installation:

pip install git+https://github.com/graphcore-research/pytorch-tensor-tracker

Usage:

Use tensor_tracker.track(module) as a context manager to start capturing tensors from within your module's forward and backward passes:

import tensor_tracker

with tensor_tracker.track(module) as tracker:
    module(inputs).backward()  # assumes module(inputs) returns a scalar, e.g. a loss

print(tracker)  # => Tracker(stashes=8, tracking=0)
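For intuition, here is a minimal sketch of how a context manager can capture module outputs and grad-outputs with standard PyTorch module hooks (register_forward_hook and register_full_backward_hook). It is an illustration under assumptions, not tensor_tracker's actual implementation: the MiniTracker class and its (name, is_grad, tensor) tuples are hypothetical.

```python
import torch
from torch import nn


class MiniTracker:
    """Hypothetical, simplified tracker that records (name, is_grad, tensor) tuples."""

    def __init__(self, module: nn.Module):
        self.module = module
        self.stashes = []
        self._handles = []

    def __enter__(self):
        # Attach hooks to every submodule, including the root (whose name is "").
        for name, sub in self.module.named_modules():
            # Forward hook: receives the submodule's output after forward() runs.
            self._handles.append(sub.register_forward_hook(self._fwd_hook(name)))
            # Full backward hook: receives gradients w.r.t. the submodule's outputs.
            self._handles.append(sub.register_full_backward_hook(self._bwd_hook(name)))
        return self

    def __exit__(self, *exc):
        # Remove all hooks so later passes are not recorded; do not suppress exceptions.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False

    def _fwd_hook(self, name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor):
                # Copy the tensor: this is why tracking can use a lot of memory.
                self.stashes.append((name, False, output.detach().clone()))
        return hook

    def _bwd_hook(self, name):
        def hook(module, grad_input, grad_output):
            for g in grad_output:
                if g is not None:
                    self.stashes.append((name, True, g.detach().clone()))
        return hook
```

Used as `with MiniTracker(model) as t: model(x).sum().backward()`, it records one forward and one backward entry per submodule. In this sketch the input is created with requires_grad=True so that the full backward hook of the first layer has a gradient path to its input.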

After the with block exits, tracker is filled with stashes containing copies of the forward and backward tensors at each (sub)module's outputs. (Note: because every tracked tensor is copied, this can consume a lot of memory.)
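To put "a lot of memory" into numbers: each stash holds a full copy of a tensor, so its size is the number of elements times the bytes per element. The shape, dtype and number of tracked outputs below are hypothetical, chosen to resemble a single transformer activation; the stash_bytes helper is illustrative only.

```python
import math


def stash_bytes(shape, bytes_per_element=4):
    # One stash copies the whole tensor: numel * element size (4 bytes for float32).
    return math.prod(shape) * bytes_per_element


# Hypothetical activation: batch 32, sequence length 512, hidden size 768, float32.
per_stash = stash_bytes((32, 512, 768))
print(per_stash / 2**20)  # 48.0 (MiB)

# Forward output plus its grad-output, for 12 tracked outputs (hypothetical count).
total = 2 * 12 * per_stash
print(total / 2**20)  # 1152.0 (MiB)
```

Tracking fewer modules, or using smaller batches while inspecting, reduces this cost proportionally.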

The tracker behaves like a list of Stash objects, each with an attached value, usually a tensor or a tuple of tensors. We can also call to_frame() to get a Pandas table of summary statistics:

print(list(tracker))
# => [Stash(name="0.linear", type=nn.Linear, grad=False, value=tensor(...)),
#     ...]

display(tracker.to_frame())  # display() is provided by Jupyter/IPython; elsewhere use print()

[Image: example output of tracker.to_frame()]
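The exact columns produced by to_frame() are described in the API documentation. As a rough illustration of the idea, the sketch below reduces a list of stashed tensors to one row of summary statistics each using pandas. SimpleStash, summary_frame and the chosen columns are hypothetical; they mirror the name/grad/value fields in the Stash repr above but are not tensor_tracker's API.

```python
from dataclasses import dataclass

import pandas as pd
import torch


@dataclass
class SimpleStash:
    # Hypothetical stand-in for a Stash: module name, gradient flag, copied tensor.
    name: str
    grad: bool
    value: torch.Tensor


def summary_frame(stashes):
    """Build one row of summary statistics per stash (column choice is an assumption)."""
    rows = [
        {
            "name": s.name,
            "grad": s.grad,
            "shape": tuple(s.value.shape),
            "mean": s.value.float().mean().item(),
            "std": s.value.float().std().item(),  # unbiased; NaN for single elements
            "abs_max": s.value.abs().max().item(),
        }
        for s in stashes
    ]
    return pd.DataFrame(rows)


stashes = [
    SimpleStash("0.linear", False, torch.tensor([[1.0, -2.0], [3.0, 0.0]])),
    SimpleStash("0.linear", True, torch.ones(2, 2)),
]
print(summary_frame(stashes))
```

Reducing each tensor to a few scalars keeps the table small even when the stashed tensors themselves are large.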

See the documentation for more information. For a more practical example, see our demo of visualising transformer activations and gradients using UMAP. To use tensor tracker on IPU with PopTorch, see Usage (PopTorch).

License

Copyright (c) 2023 Graphcore Ltd. Licensed under the MIT License (LICENSE).

Our dependencies are (see requirements.txt):

| Component | About                      | License      |
| --------- | -------------------------- | ------------ |
| torch     | Machine learning framework | BSD 3-Clause |

We also use additional Python dependencies for development/testing (see requirements-dev.txt).