Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cuda_graph for mts_gpu_benchmark #1012

Closed
wants to merge 1 commit into the base branch from the author's branch
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fx2ait/fx2ait/ait_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _lower_model_to_backend(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interpreter_result,
)
Expand Down
3 changes: 2 additions & 1 deletion fx2ait/fx2ait/csrc/AITModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ static auto registerAITModel =
std::vector<std::string>,
std::optional<at::ScalarType>,
std::optional<at::ScalarType>,
int64_t>())
int64_t,
bool>())
.def("forward", &AITModel::forward)
.def("profile", &AITModel::profile)
.def("get_library_path", &AITModel::libraryPath)
Expand Down
4 changes: 3 additions & 1 deletion fx2ait/fx2ait/csrc/AITModelImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ AITModelImpl::AITModelImpl(
floating_point_input_dtype_(input_dtype),
floating_point_output_dtype_(output_dtype),
use_cuda_graph_(use_cuda_graph) {
LOG(INFO) << "Loading .so lib " << model_path;
LOG(INFO) << "AITModelImpl: loading .so lib " << model_path;
LOG(INFO) << "AITModelImpl: num_runtimes: " << num_runtimes
<< ",use_cuda_graph: " << use_cuda_graph;
TORCH_CHECK(handle_, "could not dlopen ", model_path, ": ", dlerror());
TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive");

Expand Down
1 change: 1 addition & 0 deletions fx2ait/fx2ait/lower/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def lower_pass(
_precision_to_torch_type(lower_settings.precision),
_precision_to_torch_type(lower_settings.output_precision),
1, # num_runtimes
False,
),
interp_res,
lower_settings.trace_ait_module,
Expand Down
2 changes: 2 additions & 0 deletions fx2ait/fx2ait/test/test_fx2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _test_fx2ait_impl(self, test_serialization=False, test_cuda_graph=False):
torch.float16,
torch.float16,
1, # num_runtimes
False,
)
)
ait_mod.engine.use_cuda_graph = test_cuda_graph
Expand Down Expand Up @@ -140,6 +141,7 @@ def forward(self, a, b, c, d):
torch.float16,
torch.float16,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down
1 change: 1 addition & 0 deletions fx2ait/fx2ait/tools/ait_minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def lower_mod_default(
torch.float16,
torch.float16,
1, # num_runtimes
False,
),
interpreter_result,
)
Expand Down
3 changes: 3 additions & 0 deletions fx2ait/fx2ait/tools/common_aten2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def run_test(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -256,6 +257,7 @@ def run_test_with_dynamic_shape(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -375,6 +377,7 @@ def benchmark(f, args):
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down
5 changes: 5 additions & 0 deletions fx2ait/fx2ait/tools/common_fx2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def run_test(
torch_dtype,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand All @@ -199,6 +200,7 @@ def run_test(
torch_dtype,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -317,6 +319,7 @@ def run_test_with_dynamic_shape(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand All @@ -329,6 +332,7 @@ def run_test_with_dynamic_shape(
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down Expand Up @@ -467,6 +471,7 @@ def benchmark(f, args):
torch.float16,
torch.float,
1, # num_runtimes
False,
),
interp_result,
)
Expand Down
Loading