Support pt2-model and cuda graph in PARAM benchmark #174

Closed
149 changes: 129 additions & 20 deletions train/compute/python/lib/pytorch/op_executor.py
@@ -1,3 +1,5 @@
from typing import Union

from ..init_helper import get_logger

logger = get_logger()
@@ -61,6 +63,40 @@ def __init__(self, name: str, op: OperatorInterface, run_options: Dict[str, Any]
)
self._label_template_fwd_bwd = f"[param|{self.name}|{{op_run_id}}|{{tag}}|{ExecutionPass.FORWARD.value}_{ExecutionPass.BACKWARD.value}]"

self.op.forward = (
torch.compile(self.op.forward)
if run_options.get("pt2-model", False)
else self.op.forward
)

self.op.backward = (
torch.compile(self.op.backward)
if run_options.get("pt2-model", False)
else self.op.backward
)

self.use_cuda_graph = run_options.get("cuda-graph", False)

self.fwd_cuda_graph = None
self.bwd_cuda_graph = None

def generate_cuda_graph(self, args: List, kwargs: Dict[str, Any]):
s = torch.cuda.Stream(self.torch_device)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
self.op.forward(*args, **kwargs)
self.op.create_grad()
self.op.backward()

self.fwd_cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.fwd_cuda_graph):
self.op.forward(*args, **kwargs)

if self.pass_type == ExecutionPass.BACKWARD:
self.bwd_cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.bwd_cuda_graph):
self.op.backward()

def run(
self, input_args: List, input_kwargs: Dict[str, Any], op_run_id: str
) -> Dict[str, Any]:
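
Note on the `pt2-model` path: when the flag is set, `__init__` above swaps `self.op.forward`/`self.op.backward` for `torch.compile`-wrapped versions. A minimal, self-contained sketch of what that does (the function and shapes below are illustrative, not from this PR):

```python
import torch

def f(x):
    return torch.nn.functional.gelu(x @ x)

compiled_f = torch.compile(f)      # wrapping is cheap; compilation is lazy
x = torch.randn(512, 512, device="cuda")
compiled_f(x)   # first call triggers TorchDynamo/Inductor compilation (slow)
compiled_f(x)   # later calls reuse the compiled artifact
```

Because `torch.compile` compiles on the first call, the benchmark's warmup iterations matter when `--pt2-model` is on: they absorb the compilation cost so it is not counted as steady-state latency.
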
@@ -82,7 +118,12 @@ def run(
return result

def _benchmark_op(
self, op: Callable, args: List, kwargs: Dict[str, Any], tag: str, label_str: str
self,
op: Union[Callable, torch.cuda.CUDAGraph],
args: List,
kwargs: Dict[str, Any],
tag: str,
label_str: str,
) -> Tuple[float, float]:
logger.debug(f"benchmarking {label_str}")
gpu_memory = 0
@@ -99,10 +140,13 @@ def _benchmark_op(
timer.start()
if self.use_cuda:
op_run_id_range = torch.cuda.nvtx.range_start(label_str)
op(*args, **kwargs)
if isinstance(op, torch.cuda.CUDAGraph):
op.replay()
elif isinstance(op, Callable):
op(*args, **kwargs)
timer.stop()
if self.use_cuda:
torch.cuda.nvtx.range_end(op_run_id_range)
torch.cuda.nvtx.range_end(op_run_id_range) # pyre-ignore[61]:
# Memory size in MB
gpu_memory = torch.cuda.max_memory_allocated(self.torch_device) / (
1048576
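
For readers unfamiliar with CUDA graphs: the `CUDAGraph` branch above only works together with `generate_cuda_graph`, because a captured graph replays the exact kernels and device addresses recorded at capture time. A minimal sketch of the capture-then-replay pattern (tensor names are illustrative; the PR captures the op's own forward/backward with the benchmark's input tensors):

```python
import torch

static_x = torch.randn(1024, 1024, device="cuda")

# Warm up on a side stream so one-time lazy initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_x.relu()
torch.cuda.current_stream().wait_stream(s)

# Capture once; the graph records the kernels and the memory they touch.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x.relu()

# Replay re-launches exactly the captured work. Addresses are baked in, so
# fresh inputs must be copied into the captured tensor in place.
new_x = torch.randn(1024, 1024, device="cuda")
static_x.copy_(new_x)
g.replay()            # static_y now holds relu(new_x)
```

This is also why the graphs are built once during the warmup tag and then replayed for every measured iteration.
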
@@ -132,29 +176,52 @@ def _benchmark_discrete(
if self.use_cuda:
tag_range = torch.cuda.nvtx.range_start(f"[param|{tag}]")

if (
tag == "warmup"
and self.use_cuda_graph
and (
self.fwd_cuda_graph is None
or (
self.pass_type == ExecutionPass.BACKWARD
and self.bwd_cuda_graph is None
)
)
):
self.generate_cuda_graph(args, kwargs)

with record_function(label_str):
for _ in range(count):
label_str = self._label_template_fwd.format(
tag=tag, op_run_id=op_run_id
)
latency, peak_memory = self._benchmark_op(
self.op.forward, args, kwargs, tag, label_str
)
if self.use_cuda_graph:
latency, peak_memory = self._benchmark_op(
self.fwd_cuda_graph, args, kwargs, tag, label_str
)
else:
latency, peak_memory = self._benchmark_op(
self.op.forward, args, kwargs, tag, label_str
)
fw_time_records.append(latency)
fw_gpu_mem_records.append(peak_memory)
if self.pass_type == ExecutionPass.BACKWARD:
self.op.create_grad()
label_str = self._label_template_bwd.format(
tag=tag, op_run_id=op_run_id
)
latency, peak_memory = self._benchmark_op(
self.op.backward, [], {}, tag, label_str
)
if self.use_cuda_graph:
latency, peak_memory = self._benchmark_op(
self.bwd_cuda_graph, [], {}, tag, label_str
)
else:
latency, peak_memory = self._benchmark_op(
self.op.backward, [], {}, tag, label_str
)
bw_time_records.append(latency)
bw_gpu_mem_records.append(peak_memory)

if self.use_cuda:
torch.cuda.nvtx.range_end(tag_range)
torch.cuda.nvtx.range_end(tag_range) # pyre-ignore[61]:
return (
fw_time_records,
fw_gpu_mem_records,
@@ -169,7 +236,7 @@ def _benchmark_loop_cuda_events(
kwargs: Dict[str, Any],
tag: str,
op_run_id: str,
) -> float:
) -> Tuple[List[float], List[float], List[float], List[float]]:
"""
Using CUDA events to record assumes that we are running on a single stream.
In this mode, we do not flush cache, assuming benefit from data in warmup.
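
A stripped-down version of the event-based timing this mode relies on (single stream, illustrative sizes):

```python
import torch

x = torch.randn(2048, 2048, device="cuda")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()            # enqueued on the current stream
y = x.mm(x)               # the work being measured
end.record()

torch.cuda.synchronize()  # make sure both events have actually occurred
print(start.elapsed_time(end))   # elapsed GPU time in milliseconds
```
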
@@ -192,6 +259,19 @@ def compute_cuda_event_delta(events: List[Tuple[Any]]):

return deltas

if (
tag == "warmup"
and self.use_cuda_graph
and (
self.fwd_cuda_graph is None
or (
self.pass_type == ExecutionPass.BACKWARD
and self.bwd_cuda_graph is None
)
)
):
self.generate_cuda_graph(args, kwargs)

fw_time_records = []
bw_time_records = []
fw_gpu_mem_records = []
@@ -207,7 +287,10 @@ def compute_cuda_event_delta(events: List[Tuple[Any]]):
op_run_id_range = torch.cuda.nvtx.range_start(label_str)
for i in range(count):
fw_events[i][0].record()
self.op.forward(*args, **kwargs)
if self.use_cuda_graph:
self.fwd_cuda_graph.replay()
else:
self.op.forward(*args, **kwargs)
fw_events[i][1].record()

torch.cuda.synchronize(self.torch_device)
@@ -230,9 +313,15 @@ def compute_cuda_event_delta(events: List[Tuple[Any]]):
torch.cuda.synchronize(self.torch_device)
op_run_id_range = torch.cuda.nvtx.range_start(label_str)
for i in range(count):
self.op.forward(*args, **kwargs)
if self.use_cuda_graph:
self.fwd_cuda_graph.replay()
else:
self.op.forward(*args, **kwargs)
bw_events[i][0].record()
self.op.backward()
if self.use_cuda_graph:
self.bwd_cuda_graph.replay()
else:
self.op.backward()
bw_events[i][1].record()

torch.cuda.synchronize(self.torch_device)
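
The backward-only measurement above re-runs forward (or replays the forward graph) inside the loop to rebuild state, but brackets only backward with events. A standalone sketch of that pattern, with illustrative tensors:

```python
import torch

w = torch.randn(1024, 1024, device="cuda", requires_grad=True)
x = torch.randn(1024, 1024, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

y = (x @ w).sum()   # forward pass, rebuilt each iteration but not timed here
start.record()
y.backward()        # only backward work lands between the two events
end.record()

torch.cuda.synchronize()
print(start.elapsed_time(end))   # backward-only time in milliseconds
```
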
@@ -254,10 +343,23 @@ def _benchmark_loop_cuda(
kwargs: Dict[str, Any],
tag: str,
op_run_id: str,
) -> float:
) -> Tuple[List[float], List[float], List[float], List[float]]:

logger.debug(f"benchmarking {self.name}|{op_run_id}|{tag}")

if (
tag == "warmup"
and self.use_cuda_graph
and (
self.fwd_cuda_graph is None
or (
self.pass_type == ExecutionPass.BACKWARD
and self.bwd_cuda_graph is None
)
)
):
self.generate_cuda_graph(args, kwargs)

fw_time_records = []
bw_time_records = []
fw_gpu_mem_records = []
@@ -275,7 +377,10 @@ def _benchmark_loop_cuda(
timer.start()
op_run_id_range = torch.cuda.nvtx.range_start(label_str)
for _i in range(count):
self.op.forward(*args, **kwargs)
if self.use_cuda_graph:
self.fwd_cuda_graph.replay()
else:
self.op.forward(*args, **kwargs)
timer.stop()
torch.cuda.nvtx.range_end(op_run_id_range)

@@ -295,8 +400,12 @@ def _benchmark_loop_cuda(
timer.start()
op_run_id_range = torch.cuda.nvtx.range_start(label_str)
for _i in range(count):
self.op.forward(*args, **kwargs)
self.op.backward()
if self.use_cuda_graph:
self.fwd_cuda_graph.replay()
self.bwd_cuda_graph.replay()
else:
self.op.forward(*args, **kwargs)
self.op.backward()
timer.stop()
torch.cuda.nvtx.range_end(op_run_id_range)
# Subtract forward time to get backward time.
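
A small, hypothetical helper showing the subtraction the comment above refers to (names and numbers are made up, not from this PR):

```python
def backward_time_ms(fwd_bwd_total_ms: float, fwd_only_ms: float, count: int) -> float:
    """Average backward latency per iteration, derived by subtraction."""
    return (fwd_bwd_total_ms - fwd_only_ms) / count

# Illustrative numbers: 100 iterations, 500 ms fwd+bwd total, 180 ms fwd-only.
print(backward_time_ms(500.0, 180.0, 100))   # 3.2
```
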
@@ -320,7 +429,7 @@ def _benchmark_loop_cpu(
kwargs: Dict[str, Any],
tag: str,
op_run_id: str,
) -> float:
) -> Tuple[List[float], List[float], List[float], List[float]]:
logger.debug(f"benchmarking [{self.name}|{op_run_id}|{tag}]")

fw_time_records = []
@@ -378,7 +487,7 @@ def _measure(
tag: str,
op_run_id: str,
result: Dict[str, Any],
) -> Dict[str, Any]:
) -> None:
logger.info(f"running [{op_run_id}] for {iteration} {tag} iteration")
fw_time_records = []
fw_mem_records = []
20 changes: 19 additions & 1 deletion train/compute/python/pytorch/run_benchmark.py
@@ -70,6 +76,16 @@ def main():
choices=["on", "off"],
help="Set option for CUDA GPU L2 cache between iterations in discrete mode.",
)
parser.add_argument(
"--pt2-model",
action="store_true",
help="Compile the model before run.",
)
parser.add_argument(
"--cuda-graph",
action="store_true",
help="Enable CUDA graph.",
)
parser.add_argument(
"--ncu", action="store_true", help="Run NSight Compute to collect metrics."
)
@@ -195,6 +205,12 @@ def main():
args = parser.parse_args()

logger = init_logging(getattr(logging, args.log_level.upper(), logging.INFO))
if args.cuda_graph:
if "cuda" not in args.device:
logger.warning(
"Cannot use --cuda-graph when not running on cuda device, cuda-graph is disabled"
)
args.cuda_graph = False

if args.version:
logger.info(f"PARAM train compute version: {__version__}")
@@ -214,6 +230,8 @@ def main():
run_options["iteration"] = args.iteration
run_options["device"] = args.device
run_options["cuda_l2_cache"] = args.cuda_l2_cache == "on"
run_options["pt2-model"] = args.pt2_model
run_options["cuda-graph"] = args.cuda_graph
run_options["resume_op_run_id"] = args.resume_id
run_options["stop_op_run_id"] = args.stop_id
run_options["run_batch_size"] = args.run_batch_size
@@ -316,7 +334,7 @@ def main():
args.profile,
use_cuda=use_cuda,
use_kineto=True,
record_shapes=False,
record_shapes=True,
experimental_config=cupti_profiler_config,
# use_cpu enables profiling and recording of CPU pytorch operators.
# This is useful in CUPTI profiler mode if we are measuring per GPU kernel metrics.
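
On the `record_shapes=True` change: with Kineto/`torch.profiler`, this records the input shapes of each op at a small collection overhead, which enables per-shape aggregation of the results. A standalone sketch (not the PARAM profiler wrapper used here):

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(256, 256, device="cuda")
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,   # attach input shapes to every recorded op
) as prof:
    x.mm(x)

print(prof.key_averages(group_by_input_shape=True)
          .table(sort_by="cuda_time_total", row_limit=10))
```
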