Skip to content

Commit

Permalink
[Pytorch Delegated Backend] Save function name in debug info (pytorch…
Browse files Browse the repository at this point in the history
…#57481)

Summary:
Pull Request resolved: pytorch#57481

This diff introduces function name to InlinedCallStack.
Since we are using InlinedCallStack for debug information in lite
interpreter as well as delegate backends, where InlinedCallStack cannot
be constructed from model source code, we need to save function name.
In the absence of function name Function* is used to get name of the
function. This is when JIT compiles code at runtime.
When that is not possible, this diff introduces a way to obtain function
name.

Test Plan:
test_backend
test_cs_debug_info_serialization

test_backend
test_cs_debug_info_serialization

Imported from OSS

Differential Revision:
D28159097
D28159097

Reviewed By: raziel, ZolotukhinM

Pulled By: kimishpatel

fbshipit-source-id: deacaea3325e27273f92ae96cf0cd0789bbd6e72
  • Loading branch information
kimishpatel authored and facebook-github-bot committed May 25, 2021
1 parent 813adf1 commit ede3f54
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 31 deletions.
12 changes: 6 additions & 6 deletions test/cpp/jit/test_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ Traceback of TorchScript (most recent call last):
return self.A0.forward(x, y) + self.B0.forward(x)
~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in FunctionName_UNKNOWN
File "<string>", line 3, in forward
def forward(self, x, y):
return x + y
Expand Down Expand Up @@ -352,13 +352,13 @@ Traceback of TorchScript (most recent call last):
return self.B0.forward(x, y) + 3
~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in FunctionName_UNKNOWN
File "<string>", line 3, in forward
def forward(self, x, y):
return self.A0.forward(x, y) + 2
~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in FunctionName_UNKNOWN
File "<string>", line 3, in forward
def forward(self, x, y):
return x + y
Expand Down Expand Up @@ -432,7 +432,7 @@ Traceback of TorchScript (most recent call last):
return self.A0.forward(x, y) + self.B0.forward(x)
~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 5, in FunctionName_UNKNOWN
File "<string>", line 5, in forward
typed_inputs: List[Any] = [x, y, ]
if self.__backend.is_available() :
_0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
Expand Down Expand Up @@ -553,7 +553,7 @@ Traceback of TorchScript (most recent call last):
return self.A0.forward(x, y) + self.B0.forward(x)
~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 5, in FunctionName_UNKNOWN
File "<string>", line 5, in forward
typed_inputs: List[Any] = [x, y, ]
if self.__backend.is_available() :
_0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
Expand All @@ -566,7 +566,7 @@ Traceback of TorchScript (most recent call last):
return self.AA0.forward(x, y) + 3
~~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in FunctionName_UNKNOWN
File "<string>", line 3, in forward
def forward(self, x, y):
return x + y
Expand Down
29 changes: 23 additions & 6 deletions test/cpp/jit/test_cs_debug_info_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,37 @@ bool validate_debug_info(
if (vec1.size() != vec2.size()) {
return false;
}
for (size_t i = 0; i < vec1.size(); i++) {
auto rhs_sr = std::get<1>(vec1[i]);
auto lhs_sr = std::get<1>(vec2[i]);
auto rhs_module = std::get<2>(vec1[i]);
auto lhs_module = std::get<2>(vec2[i]);
while (csptr1) {
auto rhs_sr = csptr1->source_range();
auto lhs_sr = csptr2->source_range();
auto rhs_module = csptr1->module_instance();
auto lhs_module = csptr2->module_instance();
std::string rhs_fn_name, lhs_fn_name;
if (csptr1->function()) {
rhs_fn_name = csptr1->function()->name();
} else {
rhs_fn_name = csptr1->function_name();
}
if (csptr2->function()) {
lhs_fn_name = csptr2->function()->name();
} else {
lhs_fn_name = csptr2->function_name();
}
if (!((rhs_module.has_value() == lhs_module.has_value()) &&
(rhs_module.has_value() &&
(rhs_module.value().class_type()->name().value() ==
lhs_module.value().class_type()->name().value()) &&
(rhs_module.value().instance_name() ==
lhs_module.value().instance_name())) &&
(rhs_sr == lhs_sr))) {
(rhs_fn_name == lhs_fn_name) && (rhs_sr == lhs_sr))) {
return false;
}
if (csptr1->callee()) {
csptr1 = csptr1->callee().value();
csptr2 = csptr2->callee().value();
} else {
csptr1 = c10::intrusive_ptr<InlinedCallStack>();
}
}
return true;
}
Expand Down
51 changes: 51 additions & 0 deletions test/cpp/jit/test_lite_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,57 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) {
testLiteModuleCompareResultTensors(m, inputs);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
Module a("A");
a.define(R"(
def bar(self, x, y):
return x + y
)");
Module b("B");
b.register_module("A0", a);
b.define(R"(
def foo(self, x, y):
return self.A0.bar(x, y) + 2
)");
Module c("C");
c.register_module("B0", b);
c.define(R"(
def forward(self, x, y):
return self.B0.foo(x, y) + 3
)");

std::vector<IValue> inputs;
inputs.emplace_back(torch::rand({2, 4}));
inputs.emplace_back(torch::rand({13, 9}));

std::stringstream ss;
c._save_for_mobile(ss, ExtraFilesMap(), true);
auto lite_m = _load_for_mobile(ss);
std::string error_pattern = R"(
Module hierarchy:top(C).B0(B).A0(A).aten::add
Traceback of TorchScript (most recent call last):
File "<string>", line 3, in FunctionName_UNKNOWN
def forward(self, x, y):
return self.B0.foo(x, y) + 3
~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in foo
def foo(self, x, y):
return self.A0.bar(x, y) + 2
~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in bar
def bar(self, x, y):
return x + y
~~~~~ <--- HERE
)";
ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}

namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static auto reg =
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ Traceback of TorchScript (most recent call last):
return self.A0.forward(x, y) + self.B0.forward(x)
~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 5, in FunctionName_UNKNOWN
File "<string>", line 5, in forward
typed_inputs: List[Any] = [x, y, ]
if self.__backend.is_available() :
_0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
Expand All @@ -163,7 +163,7 @@ Traceback of TorchScript (most recent call last):
return self.AA0.forward(x, y) + 3
~~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 3, in FunctionName_UNKNOWN
File "<string>", line 3, in forward
def forward(self, x, y):
return x + y
Expand Down
36 changes: 32 additions & 4 deletions torch/csrc/jit/ir/scope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,35 @@ InlinedCallStackPtr InlinedCallStack::intrusive_from_this() {
}

InlinedCallStack::InlinedCallStack(Function* fn, SourceRange source_range)
: fn_(fn), source_range_(std::move(source_range)) {}
: fn_(fn), source_range_(std::move(source_range)) {
if (fn_) {
set_function_name(fn_->name());
}
}

InlinedCallStack::InlinedCallStack(
Function* fn,
SourceRange source_range,
c10::optional<ModuleInstanceInfo> module_instance_info)
: fn_(fn),
source_range_(std::move(source_range)),
module_instance_info_(std::move(module_instance_info)) {}
module_instance_info_(std::move(module_instance_info)) {
if (fn_) {
set_function_name(fn_->name());
}
}

InlinedCallStack::InlinedCallStack(
InlinedCallStackPtr callee,
Function* fn,
SourceRange source_range)
: callee_(std::move(callee)),
fn_(fn),
source_range_(std::move(source_range)) {}
source_range_(std::move(source_range)) {
if (fn_) {
set_function_name(fn_->name());
}
}

InlinedCallStack::InlinedCallStack(
InlinedCallStackPtr callee,
Expand All @@ -114,7 +126,11 @@ InlinedCallStack::InlinedCallStack(
: callee_(std::move(callee)),
fn_(fn),
source_range_(std::move(source_range)),
module_instance_info_(std::move(module_instance_info)) {}
module_instance_info_(std::move(module_instance_info)) {
if (fn_) {
set_function_name(fn_->name());
}
}

c10::optional<InlinedCallStackPtr> InlinedCallStack::callee() const {
return callee_;
Expand All @@ -132,6 +148,18 @@ SourceRange InlinedCallStack::source_range() const {
return source_range_;
}

Function* InlinedCallStack::function() const {
return fn_;
}

void InlinedCallStack::set_function_name(std::string fn_name) {
fn_name_ = std::move(fn_name);
}

std::string InlinedCallStack::function_name() const {
return fn_name_;
}

std::vector<InlinedCallStackEntry> InlinedCallStack::vec() {
std::vector<InlinedCallStackEntry> r;
c10::optional<InlinedCallStackPtr> current = intrusive_from_this();
Expand Down
15 changes: 15 additions & 0 deletions torch/csrc/jit/ir/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target {
private:
c10::optional<InlinedCallStackPtr> callee_;
Function* fn_;
// Reason for fn_name_ even though we have fn_
// Serialized callstack is used in circustmances where InlinedCallstack
// cannot be constructed during runtime, e.g. mobile runtime or
// delegated backends.
// Since in those cases we do not have Function* we store function name
// fn_name does not give you access to the same information that Function*
// does, however in mobile/delegated backend runtime we use InlindedCallStack
// for exception stack and for that purpose fn_name_ suffices.
std::string fn_name_;
SourceRange source_range_;
InlinedCallStackPtr intrusive_from_this();
c10::optional<ModuleInstanceInfo> module_instance_info_;
Expand Down Expand Up @@ -155,6 +164,12 @@ struct TORCH_API InlinedCallStack : public c10::intrusive_ptr_target {
// Returns the source range of the node
SourceRange source_range() const;

Function* function() const;

void set_function_name(std::string fn_name);

std::string function_name() const;

// Return callstack as a vector of [Function, SourceRange] pairs.
std::vector<InlinedCallStackEntry> vec();

Expand Down
30 changes: 22 additions & 8 deletions torch/csrc/jit/mobile/debug_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace jit {
namespace {

std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
const DebugInfoTuple& source_callstack) {
const DebugInfoTuple& source_callstack,
const std::string& caller_name) {
constexpr size_t kSourceRange = 1;
constexpr size_t kModuleInstanceInfo = 2;
std::vector<StackEntry> entries;
Expand All @@ -23,15 +24,15 @@ std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy
std::get<kDebugInfoTupleSourceRangeIndex>(source_callstack);
InlinedCallStackPtr callstack_ptr =
std::get<kDebugInfoTupleInlinedCSIndex>(source_callstack);
std::string prev_function_name = caller_name;
std::string module_info;
if (!callstack_ptr) {
// If not cs then top level node
entries.emplace_back(StackEntry{"FunctionName_UNKNOWN", range});
entries.emplace_back(StackEntry{prev_function_name, range});
return {std::move(entries), std::move(module_info)};
} else {
for (const auto& element : callstack_ptr->vec()) {
const auto& opt_module_instance_info =
std::get<kModuleInstanceInfo>(element);
while (callstack_ptr) {
const auto& opt_module_instance_info = callstack_ptr->module_instance();
if (opt_module_instance_info.has_value()) {
const auto& module_instance_info = opt_module_instance_info.value();
if (module_instance_info.class_type()) {
Expand All @@ -57,9 +58,20 @@ std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy
// When we serialize function names, those can be added here.
// TODO: Add function name separately
entries.emplace_back(
StackEntry{"FunctionName_UNKNOWN", std::get<kSourceRange>(element)});
StackEntry{prev_function_name, callstack_ptr->source_range()});
if (callstack_ptr->function()) {
prev_function_name = callstack_ptr->function()->name();
} else {
prev_function_name = callstack_ptr->function_name();
}

if (callstack_ptr->callee()) {
callstack_ptr = callstack_ptr->callee().value();
} else {
callstack_ptr = c10::intrusive_ptr<InlinedCallStack>();
}
}
entries.emplace_back(StackEntry{"FunctionName_UNKNOWN", range});
entries.emplace_back(StackEntry{prev_function_name, range});
return {std::move(entries), std::move(module_info)};
}
}
Expand All @@ -78,8 +90,10 @@ std::pair<std::string, std::string> getStackTraceWithModuleHierarchy(
std::vector<StackEntry> stack_entries;
std::string module_info =
root_scope_string + "(" + top_module_type_name + ")";
std::string caller_fn_name = "FunctionName_UNKNOWN";
for (const auto& debug_info : source_callstacks) {
auto debug_info_pair = getStackTraceWithModuleHierarchy(debug_info);
auto debug_info_pair =
getStackTraceWithModuleHierarchy(debug_info, caller_fn_name);
auto entries = std::move(debug_info_pair.first);
stack_entries.insert(stack_entries.end(), entries.begin(), entries.end());
module_info += debug_info_pair.second;
Expand Down
Loading

0 comments on commit ede3f54

Please sign in to comment.