Skip to content

Commit

Permalink
Address review comments and fix style issues
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jun 27, 2024
1 parent 94ffee3 commit f1dd62f
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
4 changes: 3 additions & 1 deletion et_replay/comm/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ def _parseExecutionTrace(
ranks = pg["ranks"]
groupCnt = pg["group_count"]
pgRanksMap[pgId] = (
ranks if len(ranks) > 0 else list(range(pg["group_size"]))
ranks
if len(ranks) > 0
else list(range(pg["group_size"]))
# rank list is empty when all ranks are in a pg
)
if ET_BACKENDID:
Expand Down
4 changes: 3 additions & 1 deletion et_replay/comm/comms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,9 @@ def __init__(self, args: Namespace) -> None:
self.quant_a2a_embedding_dim = args.quant_a2a_embedding_dim
self.quant_threshold = args.quant_threshold
self.dcheck = args.c
self.groupRanks = {} # record what ranks each process group will work on {pg_id, ranks}
self.groupRanks = (
{}
) # record what ranks each process group will work on {pg_id, ranks}
self.use_ext_dist = args.use_ext_dist
self.size_from_trace = False
self.init_method = args.init_method
Expand Down
7 changes: 4 additions & 3 deletions et_replay/tools/et_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,11 +608,11 @@ def allocate_tensors(self):
self.tensors_mapping[
(node.id, tuple(node.inputs[2][:5]), True)
]
][i] = i * nnz
][i] = (i * nnz)
else:
self.tensor_registry_permanent[
self.tensors_mapping[(node.id, tuple(node.inputs[2]), True)]
][i] = i * nnz
][i] = (i * nnz)
######

def build_func(self, node):
Expand Down Expand Up @@ -1183,7 +1183,8 @@ def run_op(self, node, iter):
not in self.unchangeable_intermediate_tensors
):
if (
self.tensors_mapping[(node.id, t_id, False)] not in self.instantiate
self.tensors_mapping[(node.id, t_id, False)]
not in self.instantiate
# and self.tensors_mapping[(node.id, t_id, False)]
# not in self.tensor_registry
):
Expand Down
2 changes: 1 addition & 1 deletion et_replay/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from typing import Any, Dict

from et_replay.execution_trace import ExecutionTrace
from et_replay import ExecutionTrace


def get_tmp_trace_filename() -> str:
Expand Down

0 comments on commit f1dd62f

Please sign in to comment.