Skip to content

Commit 6694fda

Browse files
Elias Ellison and pytorchmergebot
Elias Ellison
authored and committed
Clean up profiling mode and profiling executor strategy (pytorch#73875)
Summary: Pull Request resolved: pytorch#73875 Previously we had a few settings: - getExecutor - which toggled between Profiling Executor and Legacy - getGraphOptimize - if true, overrides PE/Legacy to run with simple executor (no optimizations) and then... - getProfilingMode - which would set PE to 0 specializations. The last mode is redundant with getGraphOptimize, we should just remove it and use getGraphOptimize in these cases. It would lead to potentially invalid combinations of logic - what does it mean if getProfilingMode is true but getExecutor is set to false? This would lead to a bug in specialize_autograd_zero in this case, see: https://github.com/pytorch/pytorch/blob/master/torch%2Fcsrc%2Fjit%2Fpasses%2Fspecialize_autogradzero.cpp#L93. The tests here are failing but get fixed with the PR above it, so I'll squash for landing. Test Plan: Imported from OSS Reviewed By: cpuhrsch Differential Revision: D34938130 Pulled By: eellison fbshipit-source-id: 1a9c0ae7f6d1cfddc2ed3499a5af611053ae5e1b (cherry picked from commit cf69ce3)
1 parent ab57876 commit 6694fda

22 files changed

+93
-91
lines changed

aten/src/ATen/core/builtin_function.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ struct BuiltinOpFunction : public Function {
6262
return *this;
6363
}
6464

65-
bool call(Stack& stack, size_t, c10::function_ref<void(const Code&)>) override {
65+
bool call(Stack& stack, c10::optional<size_t>, c10::function_ref<void(const Code&)>) override {
6666
run(stack);
6767
return false;
6868
}

aten/src/ATen/core/function.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct TORCH_API Function {
9090
// call() returns false.
9191

9292
// Overload for server interpreter, a bailout size is needed for graph executor.
93-
virtual bool call(Stack&, size_t, c10::function_ref<void(const Code&)>) {
93+
virtual bool call(Stack&, c10::optional<size_t>, c10::function_ref<void(const Code&)>) {
9494
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
9595
return false;
9696
}

benchmarks/fastrnns/fuser.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@ def set_fuser(fuser_name, executor_name):
44
assert fuser_name in ['te', 'old', 'none', 'default']
55
if fuser_name == 'te':
66
torch._C._jit_set_profiling_executor(True)
7-
torch._C._jit_set_profiling_mode(True)
7+
torch._C._get_graph_executor_optimize(True)
88
torch._C._jit_override_can_fuse_on_cpu(False)
99
torch._C._jit_override_can_fuse_on_gpu(True)
1010
torch._C._jit_set_texpr_fuser_enabled(True)
1111
elif fuser_name == 'old':
1212
torch._C._jit_set_profiling_executor(False)
13-
torch._C._jit_set_profiling_mode(False)
13+
torch._C._get_graph_executor_optimize(False)
1414
torch._C._jit_override_can_fuse_on_gpu(True)
1515
torch._C._jit_set_texpr_fuser_enabled(False)
1616
elif fuser_name == 'none':
1717
torch._C._jit_set_profiling_executor(False)
18-
torch._C._jit_set_profiling_mode(False)
18+
torch._C._get_graph_executor_optimize(False)
1919
torch._C._jit_override_can_fuse_on_gpu(False)
2020
torch._C._jit_override_can_fuse_on_cpu(False)
2121
torch._C._jit_set_texpr_fuser_enabled(False)
@@ -25,12 +25,11 @@ def set_fuser(fuser_name, executor_name):
2525
# --executor overrides settings of --fuser
2626
if executor_name == 'profiling':
2727
torch._C._jit_set_profiling_executor(True)
28-
torch._C._jit_set_profiling_mode(True)
28+
torch._C._get_graph_executor_optimize(True)
2929
elif executor_name == 'simple':
30-
torch._C._jit_set_profiling_executor(True)
31-
torch._C._jit_set_profiling_mode(False)
30+
torch._C._get_graph_executor_optimize(False)
3231
elif executor_name == 'legacy':
3332
torch._C._jit_set_profiling_executor(False)
34-
torch._C._jit_set_profiling_mode(False)
33+
torch._C._get_graph_executor_optimize(True)
3534
elif executor_name == 'default':
3635
pass

benchmarks/tensorexpr/__main__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def main():
137137
torch._C._jit_set_profiling_executor(True)
138138
torch._C._jit_set_texpr_fuser_enabled(True)
139139
torch._C._jit_override_can_fuse_on_gpu(True)
140-
torch._C._jit_set_profiling_mode(True)
140+
torch._C._get_graph_executor_optimize(True)
141141
elif args.cuda_fuser == "old":
142142
import torch
143143
torch._C._jit_set_profiling_executor(False)
@@ -148,7 +148,7 @@ def main():
148148
torch._C._jit_set_profiling_executor(True)
149149
torch._C._jit_set_texpr_fuser_enabled(False)
150150
torch._C._jit_set_nvfuser_enabled(True)
151-
torch._C._jit_set_profiling_mode(True)
151+
torch._C._get_graph_executor_optimize(True)
152152
else :
153153
raise ValueError("Undefined fuser: {}".format(args.cuda_fuser))
154154

test/cpp/jit/test_autodiff.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -289,14 +289,11 @@ class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
289289
void SetUp() override {
290290
prev_exec = getExecutorMode();
291291
getExecutorMode() = true;
292-
prev_profiling = getProfilingMode();
293-
getProfilingMode() = true;
294292
prev_inline_autodiff = getAutodiffSubgraphInlining();
295293
debugSetAutodiffSubgraphInlining(false);
296294
}
297295
void TearDown() override {
298296
getExecutorMode() = prev_exec;
299-
getProfilingMode() = prev_profiling;
300297
debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
301298
}
302299

test/jit/test_profiler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
class TestProfiler(JitTestCase):
1919
def setUp(self):
2020
self.prev_exec = torch._C._jit_set_profiling_executor(True)
21-
self.prev_profiling = torch._C._jit_set_profiling_mode(True)
21+
self.prev_profiling = torch._C._get_graph_executor_optimize(True)
2222
self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False)
2323
self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
2424
self.can_fuse_on_cpu = torch._C._jit_can_fuse_on_cpu()
@@ -34,7 +34,7 @@ def setUp(self):
3434

3535
def tearDown(self):
3636
torch._C._jit_set_profiling_executor(self.prev_exec)
37-
torch._C._jit_set_profiling_mode(self.prev_profiling)
37+
torch._C._get_graph_executor_optimize(self.prev_profiling)
3838
torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff)
3939
torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
4040
torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)

test/test_jit.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,6 @@ def doAutodiffCheck(testname):
204204
# TODO: enable TE in PE when all tests are fixed
205205
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
206206
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
207-
# even though FULL_PROFILER should be our default
208-
# we haven't tested every single test in this file
209-
# but we enable FULL_PROFILER for a large subset
210-
# of the tests with "with enable_profiling_mode_for_profiling_tests"
211-
torch._C._jit_set_profiling_mode(False)
212207

213208
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
214209
hx, cx = hidden
@@ -7360,7 +7355,7 @@ def test_as_tensor_tensor_input(input):
73607355
g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
73617356
FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g)
73627357

7363-
7358+
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior")
73647359
def test_tensor_requires_grad(self):
73657360
@torch.jit.script
73667361
def test(b):

test/test_jit_fuser_te.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# inferred erroneously runs or skips
1919
# some tests
2020
torch._C._jit_set_profiling_executor(True)
21-
torch._C._jit_set_profiling_mode(True)
21+
torch._C._get_graph_executor_optimize(True)
2222

2323
from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, \
2424
enable_profiling_mode_for_profiling_tests, slowTest
@@ -2608,7 +2608,7 @@ def setUp(self):
26082608
torch._C._jit_override_can_fuse_on_gpu(True)
26092609

26102610
self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
2611-
self.old_profiling_mode = torch._C._jit_set_profiling_mode(True)
2611+
self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)
26122612

26132613
self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
26142614
torch._C._debug_set_fusion_group_inlining(False)
@@ -2625,7 +2625,7 @@ def setUp(self):
26252625

26262626
def tearDown(self):
26272627
torch._C._jit_set_profiling_executor(self.old_profiling_executor)
2628-
torch._C._jit_set_profiling_mode(self.old_profiling_mode)
2628+
torch._C._get_graph_executor_optimize(self.old_profiling_mode)
26292629

26302630
torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
26312631
torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)

torch/_C/__init__.pyi.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
283283
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
284284
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
285285
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
286-
def _get_graph_executor_optimize() -> _bool: ...
286+
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
287287
def _set_graph_executor_optimize(optimize: _bool): ...
288288
def _export_opnames(module: ScriptModule) -> List[str]: ...
289289
def _create_function_from_trace(

torch/csrc/jit/api/function_impl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct TORCH_API GraphFunction : public Function {
9999
using Function::call;
100100
bool call(
101101
Stack& stack,
102-
size_t bailOut,
102+
c10::optional<size_t> bailOut,
103103
c10::function_ref<void(const Code&)> f) override {
104104
f(get_executor().getPlanFor(stack, bailOut).code);
105105
return true;

torch/csrc/jit/passes/specialize_autogradzero.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct AutogradZeroSpecializer {
9090
if (!isBackwardGraph()) {
9191
return;
9292
}
93-
if (getProfilingMode()) {
93+
if (getExecutorMode()) {
9494
if (auto versioning_if = guardSpecializations()) {
9595
specializeAutogradOps(versioning_if->blocks()[0]);
9696
GRAPH_DUMP("After versioning graph", graph_);

torch/csrc/jit/python/script_init.cpp

+10-1
Original file line numberDiff line numberDiff line change
@@ -2009,7 +2009,16 @@ void initJitScriptBindings(PyObject* module) {
20092009
setGraphExecutorOptimize(optimize);
20102010
});
20112011

2012-
m.def("_get_graph_executor_optimize", &torch::jit::getGraphExecutorOptimize);
2012+
m.def(
2013+
"_get_graph_executor_optimize",
2014+
[](c10::optional<bool> new_setting = c10::nullopt) {
2015+
bool old_value = getGraphExecutorOptimize();
2016+
if (new_setting) {
2017+
setGraphExecutorOptimize(*new_setting);
2018+
}
2019+
return old_value;
2020+
},
2021+
py::arg("new_settings") = nullptr);
20132022

20142023
m.def(
20152024
"_enable_mobile_interface_call_export",

torch/csrc/jit/runtime/graph_executor.cpp

+15-23
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
#include <torch/csrc/autograd/edge.h>
4444
#include <torch/csrc/autograd/function.h>
45+
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
4546
#include <torch/csrc/jit/runtime/logging.h>
4647

4748
#include <cstdint>
@@ -56,17 +57,16 @@ namespace torch {
5657
namespace jit {
5758

5859
EnableProfilingGuard::EnableProfilingGuard() {
59-
auto& profiling_mode = getProfilingMode();
60-
old_profiling_mode = profiling_mode;
61-
profiling_mode = true;
6260
auto& executor_mode = getExecutorMode();
6361
old_executor_mode = executor_mode;
6462
executor_mode = true;
63+
old_get_optimize = getGraphExecutorOptimize();
64+
setGraphExecutorOptimize(true);
6565
}
6666

6767
EnableProfilingGuard::~EnableProfilingGuard() {
68-
getProfilingMode() = old_profiling_mode;
6968
getExecutorMode() = old_executor_mode;
69+
setGraphExecutorOptimize(old_get_optimize);
7070
}
7171

7272
namespace {
@@ -408,8 +408,7 @@ struct DifferentiableGraphOp {
408408

409409
detachVariables(stack);
410410
if (IsNewExecutorEnabled()) {
411-
const ExecutionPlan& plan =
412-
f_ptr->getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
411+
const ExecutionPlan& plan = f_ptr->getPlanFor(stack);
413412
InterpreterState(plan.code).run(stack);
414413
} else {
415414
InterpreterState(legacy_f).run(stack);
@@ -550,8 +549,7 @@ void GraphExecutorImplBase::run(Stack& stack) {
550549
logging::getLogger()->addStatValue(
551550
logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
552551

553-
const ExecutionPlan& plan =
554-
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts());
552+
const ExecutionPlan& plan = getPlanFor(stack);
555553
InterpreterState(plan.code).run(stack);
556554
last_executed_optimized_graph = plan.graph;
557555
}
@@ -576,9 +574,8 @@ c10::intrusive_ptr<Future> GraphExecutorImplBase::runAsync(
576574
ExecutionPlan plan;
577575
InterpreterState state;
578576
};
579-
auto frame = std::make_shared<Frame>(
580-
getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts()),
581-
std::move(taskLauncher));
577+
auto frame =
578+
std::make_shared<Frame>(getPlanFor(stack), std::move(taskLauncher));
582579
auto res = frame->state.runAsync(stack);
583580
last_executed_optimized_graph = frame->plan.graph;
584581
if (!res->completed()) {
@@ -603,8 +600,9 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
603600
logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
604601
}
605602

606-
const ExecutionPlan& getPlanFor(Stack& stack, size_t remaining_bailout_depth)
607-
override {
603+
const ExecutionPlan& getPlanFor(
604+
Stack& stack,
605+
c10::optional<size_t> remaining_bailout_depth) override {
608606
return getGraphExecutorOptimize() ? getOrCompile(stack)
609607
: getOrCompileFallback();
610608
}
@@ -783,13 +781,9 @@ c10::intrusive_ptr<Future> GraphExecutor::runAsync(
783781
return pImpl->runAsync(stack, std::move(taskLauncher));
784782
}
785783

786-
size_t GraphExecutor::getDefaultNumBailOuts() {
787-
return getProfilingMode() ? getBailoutDepth() : 0;
788-
}
789-
790784
const ExecutionPlan& GraphExecutor::getPlanFor(
791785
Stack& inputs,
792-
size_t remaining_bailout_depth) {
786+
c10::optional<size_t> remaining_bailout_depth) {
793787
return pImpl->getPlanFor(inputs, remaining_bailout_depth);
794788
}
795789

@@ -887,10 +881,8 @@ void runNondiffOptimization(
887881

888882
// decomposition pass, decompose certain ops that will be used in the
889883
// following passes (like batchmm and jit fusion)
890-
if (!getProfilingMode()) {
891-
DecomposeOps(graph);
892-
GRAPH_DEBUG("After DecomposeOps\n", *graph);
893-
}
884+
DecomposeOps(graph);
885+
GRAPH_DEBUG("After DecomposeOps\n", *graph);
894886

895887
// TupleConstruct / TupleUnpack pairs can still be present at this point
896888
// and must be removed for fusion.
@@ -901,7 +893,7 @@ void runNondiffOptimization(
901893
BatchMM(graph);
902894

903895
GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
904-
if (getProfilingMode()) {
896+
if (getExecutorMode()) {
905897
if (tensorExprFuserEnabled()) {
906898
auto min_size = getFusionGroupInlining() ? 2 : 1;
907899
auto dyn_shapes = tensorExprDynamicShapeFusionEnabled();

torch/csrc/jit/runtime/graph_executor.h

+8-12
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,8 @@ struct Code;
1818

1919
struct ExecutionPlan {
2020
ExecutionPlan() = default;
21-
ExecutionPlan(
22-
std::shared_ptr<Graph> graph,
23-
std::string function_name,
24-
size_t remaining_bailout_depth = 0)
25-
: code(graph, std::move(function_name), remaining_bailout_depth),
26-
graph(std::move(graph)) {}
21+
ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
22+
: code(graph, std::move(function_name)), graph(std::move(graph)) {}
2723

2824
operator bool() const {
2925
return static_cast<bool>(graph);
@@ -34,8 +30,8 @@ struct ExecutionPlan {
3430
};
3531

3632
// Notice that those structs don't manage lifetime of their members.
37-
// They is only valid only right after you call getDebugState() and should never
38-
// be used again once another GraphExecutor function is called.
33+
// They are only valid only right after you call getDebugState() and should
34+
// never be used again once another GraphExecutor function is called.
3935

4036
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
4137
struct GraphExecutorState {
@@ -50,7 +46,7 @@ struct TORCH_API EnableProfilingGuard {
5046

5147
private:
5248
bool old_executor_mode = false;
53-
bool old_profiling_mode = false;
49+
bool old_get_optimize = false;
5450
};
5551

5652
struct GraphExecutorImplBase;
@@ -72,13 +68,13 @@ struct TORCH_API GraphExecutor {
7268
// profiled information whenever a bailout check is failed/triggered, a new
7369
// `GraphExecutor` will be created. This new `GraphExecutor`'s
7470
// remaining_bailout_depth will be reduced by 1.
71+
// If no bailout depth is passed, the depth will be initialized from the
72+
// current global fusion strategy settings.
7573
const ExecutionPlan& getPlanFor(
7674
Stack& inputs,
77-
size_t remaining_bailout_depth);
75+
c10::optional<size_t> remaining_bailout_depth = c10::nullopt);
7876
GraphExecutorState getDebugState();
7977

80-
static size_t getDefaultNumBailOuts();
81-
8278
void debugFlushCompilationCache();
8379

8480
bool isOptimized() const;

torch/csrc/jit/runtime/graph_executor_impl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ struct GraphExecutorImplBase {
7979

8080
virtual const ExecutionPlan& getPlanFor(
8181
Stack& stack,
82-
size_t remaining_bailout_depth) = 0;
82+
c10::optional<size_t> remaining_bailout_depth = c10::nullopt) = 0;
8383
virtual GraphExecutorState getDebugState() = 0;
8484
virtual ~GraphExecutorImplBase() = default;
8585

torch/csrc/jit/runtime/interpreter.cpp

+2-5
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
175175
void callFunction(
176176
Function& f,
177177
Stack& stack,
178-
size_t bailOut = GraphExecutor::getDefaultNumBailOuts(),
178+
c10::optional<size_t> bailOut = c10::nullopt,
179179
bool next = true) {
180180
bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
181181
enterFrame(code, stack.size() - code.num_inputs());
@@ -716,10 +716,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
716716
auto& forked_fn =
717717
toGraphFunction(*frame.function->function_table_[inst.X]);
718718
InterpreterState forked_interpreter(
719-
forked_fn.get_executor()
720-
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
721-
.code,
722-
taskLauncher_);
719+
forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
723720
InterpreterContinuation continuation(
724721
forked_interpreter,
725722
Stack(stack.end() - inst.N, stack.end()),

0 commit comments

Comments
 (0)