Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 13, 2024
1 parent 06af9d2 commit 8e3bb88
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
12 changes: 10 additions & 2 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
val[2] = N

@staticmethod
def print(prefix=None): # noqa: T202
def print(prefix=None) -> str: # noqa: T202
"""Prints the state of the timer.
Returns:
the string printed using the logger.
"""
keys = list(timeit._REG)
keys.sort()
string = []
for name in keys:
strings = []
if prefix:
strings.append(prefix)
strings.append(
f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
)
logger.info(" -- ".join(strings))
string.append(" -- ".join(strings))
logger.info(string[-1])
return "\n".join(string)

@classmethod
def todict(cls, percall=True):
Expand Down
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@
print("pkg_path", pkg_path)

# Print the structor of our temporary directory, including file size
tensordict.utils.print_directory_tree(tmpdir)
print(tensordict.utils.print_directory_tree(tmpdir))

compiled_module = aoti_load_package(pkg_path)

Expand Down Expand Up @@ -463,7 +463,7 @@ def onnx_policy(screen_obs: np.ndarray) -> int:
with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
env.rollout(num_steps, policy_explore)

timeit.print()
print(timeit.print())

#####################################
# Note that ONNX also offers the possibility of optimizing models directly, but this is beyond the scope of this
Expand Down

0 comments on commit 8e3bb88

Please sign in to comment.