
Commit b0c27b4

tugsbayasgalan authored and facebook-github-bot committed
Enable backward/forward compatibility for TS runtime (pytorch#57498)
Summary: Pull Request resolved: pytorch#57498

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D28162448

Pulled By: tugsbayasgalan

fbshipit-source-id: 5c21ced42a22aca7cee089e876e9d98d32f68955
1 parent b38f153 commit b0c27b4

11 files changed (+116, -48 lines)

Diff for: test/expect/TestJit.test_import_method.expect (+1, -2)

@@ -1,5 +1,4 @@
 def forward(self,
     x: Tensor,
     y: Tensor) -> Tensor:
-  _0 = torch.add(torch.mul(x, 2), y, alpha=1)
-  return _0
+  return torch.add(torch.mul(x, 2), y)
Diff for: test/expect/TestJit.test_pretty_printer-loop_use_test.expect (+3, -3)

@@ -1,10 +1,10 @@
 def loop_use_test(y: Tensor) -> Tuple[Tensor, Tensor]:
-  x = torch.add(y, 1, 1)
-  z = torch.add(x, 5, 1)
+  x = torch.add(y, 1)
+  z = torch.add(x, 5)
   z0 = z
   y0 = y
   _0 = bool(torch.lt(y, 8))
   while _0:
-    y1 = torch.add_(y0, 1, 1)
+    y1 = torch.add_(y0, 1)
     _0, z0, y0 = bool(torch.lt(y1, 8)), x, y1
   return (x, z0)

Diff for: test/expect/TestJit.test_pretty_printer-while_if_test.expect (+3, -3)

@@ -5,11 +5,11 @@ def while_if_test(a: Tensor,
   b0 = b
   _0 = bool(torch.lt(a, 10))
   while _0:
-    a1 = torch.add(a0, 1, 1)
-    b1 = torch.add(b0, 1, 1)
+    a1 = torch.add(a0, 1)
+    b1 = torch.add(b0, 1)
     if bool(torch.gt(a1, b1)):
       c0 = 2
     else:
       c0 = 3
     _0, a0, c, b0 = bool(torch.lt(a1, 10)), a1, c0, b1
-  return torch.add(torch.add(a0, 1, 1), c, 1)
+  return torch.add(torch.add(a0, 1), c)

Diff for: test/expect/TestJit.test_pretty_printer-while_test.expect (+1, -1)

@@ -5,6 +5,6 @@ def while_test(a: Tensor,
   _0 = bool(torch.lt(i, 3))
   while _0:
     a1 = torch.mul_(a0, a0)
-    i1 = torch.add_(i0, 1, 1)
+    i1 = torch.add_(i0, 1)
     _0, a0, i0 = bool(torch.lt(i1, 3)), a1, i1
   return a0
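The expect files above are captured pretty-printer output. A minimal sketch for inspecting the same output on a local build (torch.jit.script and the .code attribute are real APIs; the exact printed text depends on the PyTorch version):

import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x * 2 + y

# With this change the trailing alpha=1 default of aten::add is no longer
# spelled out, so the printed body should look like the updated expect file:
#   return torch.add(torch.mul(x, 2), y)
print(f.code)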

Diff for: test/jit/test_ignorable_args.py (+47, new file)

@@ -0,0 +1,47 @@
+import os
+import sys
+
+from torch._C import parse_ir
+from torch.testing import FileCheck
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == '__main__':
+    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+                       "\tpython test/test_jit.py TESTNAME\n\n"
+                       "instead.")
+
+# Tests that trailing arguments equal to the schema defaults can be ignored in TorchScript
+class TestIgnorableArgs(JitTestCase):
+    def test_slice_ignorable_args_for_slice(self):
+        graph_str = """graph():
+            %15 : int = prim::Constant[value=9223372036854775807]()
+            %13 : int = prim::Constant[value=0]()
+            %10 : bool = prim::Constant[value=0]()
+            %8 : NoneType = prim::Constant()
+            %0 : int = prim::Constant[value=1]()
+            %1 : int = prim::Constant[value=2]()
+            %2 : int = prim::Constant[value=3]()
+            %3 : int = prim::Constant[value=4]()
+            %4 : int = prim::Constant[value=9]()
+            %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
+            %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
+            %7 : int[][] = prim::ListConstruct(%5, %6)
+            %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
+            %16 : Tensor = aten::slice(%val.1, %13, %1, %15, %0)
+            %20 : Tensor = aten::slice(%16, %0, %13, %0, %0)
+            return (%20)"""
+        graph = parse_ir(graph_str)
+        function = self.createFunctionFromGraph(graph)
+        function_copy = self.getExportImportCopy(function)
+        src = str(function.code)
+        # For a signature:
+        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
+        # we ignore trailing arguments after start=2 for dim 0
+        # and after end=1 for dim 1,
+        # because in %16, %15 and %0 are the default values for the schema.
+        FileCheck().check("torch.slice(torch.tensor(_0), 0, 2), 1, 0, 1)").run(src)
+        self.assertEqual(function(), function_copy())
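To unpack the FileCheck pattern: assuming the aten::slice defaults are start=0, end=9223372036854775807 (INT64_MAX) and step=1, which is what the constants in the graph suggest, the two slice nodes print differently (illustration only, not part of the commit):

# %16 : aten::slice(%val.1, dim=0, start=2, end=INT64_MAX, step=1)
#   end and step equal their defaults, so only dim and start are printed:
#     torch.slice(torch.tensor(_0), 0, 2)
# %20 : aten::slice(%16, dim=1, start=0, end=1, step=1)
#   end=1 differs from its default, so every argument up to end is kept:
#     torch.slice(..., 1, 0, 1)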

Diff for: test/test_jit.py (+1)

@@ -35,6 +35,7 @@
 from jit.test_string_formatting import TestStringFormatting  # noqa: F401
 from jit.test_profiler import TestProfiler  # noqa: F401
 from jit.test_slice import TestSlice  # noqa: F401
+from jit.test_ignorable_args import TestIgnorableArgs  # noqa: F401
 from jit.test_hooks import TestHooks  # noqa: F401
 from jit.test_warn import TestWarn  # noqa: F401
 from jit.test_isinstance import TestIsinstance  # noqa: F401

Diff for: torch/csrc/jit/runtime/calculate_necessary_args.h (+43, new file)

@@ -0,0 +1,43 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/frontend/schema_matching.h>
+#include <cstddef>
+
+namespace torch {
+namespace jit {
+
+inline size_t CalculateNecessaryArgs(
+    const std::vector<Argument>& schema_args,
+    at::ArrayRef<Value*> actual_inputs) {
+  if (schema_args.size() < actual_inputs.size()) {
+    return actual_inputs.size();
+  }
+  // keeps track of trailing unnecessary args
+  int schema_size = schema_args.size();
+  for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
+    // this argument has no default value, so it is necessary
+    if (!schema_args.at(schema_idx).default_value().has_value()) {
+      return schema_idx + 1;
+    } else {
+      auto schema_value =
+          schema_args.at(schema_idx).default_value().value().toIValue();
+      // a non-constant value has no IValue here, so it is marked necessary;
+      // non-constant includes prim::ListConstruct and prim::DictConstruct
+      // as well.
+      auto actual_value = toIValue(actual_inputs[schema_idx]);
+      if (!actual_value.has_value()) {
+        return schema_idx + 1;
+      }
+      // if the IR has the same value as the default value of the schema,
+      // it is not a necessary argument.
+      if (schema_value != actual_value.value()) {
+        return schema_idx + 1;
+      }
+    }
+  }
+  return 0;
+}
+
+} // namespace jit
+} // namespace torch
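A minimal Python sketch of the same trailing-default trimming, handy for reasoning about the header above (the helper and its data layout are hypothetical, not a PyTorch API):

from typing import Any, List, Optional, Tuple

def necessary_arg_count(args: List[Tuple[Optional[Any], Optional[Any]]]) -> int:
    """args holds one (schema_default, actual_constant) pair per position;
    None means "no default" / "not a compile-time constant" respectively."""
    # Walk backwards over trailing arguments and stop at the first one that must stay.
    for idx in range(len(args) - 1, -1, -1):
        default, actual = args[idx]
        if default is None:      # no schema default -> always necessary
            return idx + 1
        if actual is None:       # not a constant (e.g. a ListConstruct) -> keep it
            return idx + 1
        if actual != default:    # differs from the default -> keep it
            return idx + 1
    return 0

# aten::add(Tensor self, Tensor other, *, Scalar alpha=1) called with alpha=1:
# self and other have no defaults, alpha matches its default, so two args remain.
print(necessary_arg_count([(None, None), (None, None), (1, 1)]))  # -> 2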

Diff for: torch/csrc/jit/runtime/interpreter.cpp (+1, -1)

@@ -811,7 +811,7 @@ const std::vector<Instruction>& Code::instructions() const {
   return pImpl->instructions();
 }
 
-const std::unordered_map<std::string, int>& Code::op_to_num_specified_args()
+const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
     const {
   return pImpl->op_to_num_specified_args();
 }

Diff for: torch/csrc/jit/runtime/interpreter.h (+2, -1)

@@ -65,7 +65,8 @@ struct TORCH_API Code {
   const std::vector<c10::IValue>& constant_table() const;
   const std::vector<c10::TypePtr>& type_table() const;
   const std::vector<Instruction>& instructions() const;
-  const std::unordered_map<std::string, int>& op_to_num_specified_args() const;
+  const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
+      const;
   const std::vector<Node*>& instructions_source() const;
   void request_bailout(size_t index);
   size_t register_size() const;

Diff for: torch/csrc/jit/runtime/interpreter/code_impl.h (+6, -34)

@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/bailout_graph.h>
+#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
 #include <torch/csrc/jit/runtime/graph_iterator.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>

@@ -101,7 +102,7 @@ struct CodeImpl {
   //   aten::foo("somestr", arg1=1, arg2=False, arg3=0.0)
   //   op_to_num_specified_args_["aten::foo.str"] = 3
   // This is because for all usages, at most 3 args are used.
-  std::unordered_map<std::string, int> op_to_num_specified_args_;
+  std::unordered_map<std::string, size_t> op_to_num_specified_args_;
 
   // running count of uses as we emit. When we reach use_count_[v] =
   // v.uses().size() we know it is the final use and we can move rather than

@@ -183,7 +184,8 @@
     return instructions_;
   }
 
-  const std::unordered_map<std::string, int>& op_to_num_specified_args() const {
+  const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
+      const {
     return op_to_num_specified_args_;
   }
 

@@ -734,12 +736,12 @@ struct MobileCodeImpl : CodeImpl {
         // skip if schema has vararg
         if (!op_schema.is_vararg()) {
           auto numInclude =
-              calculate_necessary_args(op_schema.arguments(), node->inputs());
+              CalculateNecessaryArgs(op_schema.arguments(), node->inputs());
           auto unique_name = op_schema.overload_name() != ""
               ? op_schema.name() + "." + op_schema.overload_name()
               : op_schema.name();
           auto it = op_to_num_specified_args_.insert(
-              std::pair<std::string, int>(unique_name, 0));
+              std::pair<std::string, size_t>(unique_name, 0));
           auto prev_value = it.first->second;
           it.first->second = std::max(numInclude, prev_value);
         }

@@ -748,36 +750,6 @@
     }
   }
 
-  int calculate_necessary_args(
-      const std::vector<Argument>& schema_args,
-      at::ArrayRef<Value*> actual_inputs) {
-    AT_ASSERT(schema_args.size() == actual_inputs.size());
-    // keeps track of trailing unnecessary args
-    int schema_size = schema_args.size();
-    for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
-      // this means it is not default argument, so it is necessary
-      if (!schema_args.at(schema_idx).default_value().has_value()) {
-        return schema_idx + 1;
-      } else {
-        auto schema_value =
-            schema_args.at(schema_idx).default_value().value().toIValue();
-        // non-const value will become nullptr here, so will be marked necessary
-        // non-const would include prim::ListConstruct, prim::DictConstruct as
-        // well.
-        auto actual_value = toIValue(actual_inputs[schema_idx]);
-        if (!actual_value.has_value()) {
-          return schema_idx + 1;
-        }
-        // if the IR has same value as default value of the schema,
-        // it is not necessary argument.
-        if (schema_value != actual_value.value()) {
-          return schema_idx + 1;
-        }
-      }
-    }
-    return 0;
-  }
-
   void emitOperator(Node* node) override {
     CodeImpl::emitOperator(node);
     // const Operator& op = node->getOperator();
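For the op_to_num_specified_args_ map touched above, the mobile code emitter records, per operator overload, the largest number of arguments any call site actually needs. A small hypothetical sketch of that bookkeeping (names are illustrative, not PyTorch APIs):

op_to_num_specified_args = {}

def record_usage(unique_name: str, num_necessary: int) -> None:
    # Mirrors the std::max(numInclude, prev_value) update in MobileCodeImpl:
    # keep the maximum over every usage of the overload.
    prev = op_to_num_specified_args.get(unique_name, 0)
    op_to_num_specified_args[unique_name] = max(num_necessary, prev)

record_usage("aten::foo.str", 2)   # one call site needs only two args
record_usage("aten::foo.str", 3)   # another spells out a non-default third arg
print(op_to_num_specified_args)    # {'aten::foo.str': 3}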

Diff for: torch/csrc/jit/serialization/python_print.cpp (+8, -3)

@@ -10,6 +10,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/resource_guard.h>
+#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
 
 #include <algorithm>
 

@@ -1156,10 +1157,14 @@ struct PythonPrintImpl {
       printOpName(stmt, node->kind());
       const FunctionSchema& schema = node->schema();
       stmt << "(";
-      for (size_t i = 0; i < node->inputs().size(); ++i) {
-        if (i > 0) {
+      // calculate how many args are specified.
+      // see (https://github.com/pytorch/pytorch/pull/56079) for more
+      // details.
+      size_t necessary_args =
+          CalculateNecessaryArgs(schema.arguments(), node->inputs());
+      for (size_t i = 0; i < necessary_args; ++i) {
+        if (i > 0)
           stmt << ", ";
-        }
         auto v = useOf(node->inputs().at(i));
         // print the kwarg name if it is a kwarg only argument.
         if (i < schema.arguments().size()) {
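A hypothetical formatter (not part of the commit) showing how the truncated loop above combines with the existing kwarg handling: only the first necessary_args inputs are emitted, and keyword-only arguments among them keep their names:

def format_call(op_name, arg_names, kwarg_only, values, necessary_args):
    parts = []
    for i in range(necessary_args):
        # keyword-only arguments are printed as name=value, positional ones bare
        part = f"{arg_names[i]}={values[i]}" if kwarg_only[i] else str(values[i])
        parts.append(part)
    return f"{op_name}({', '.join(parts)})"

# aten::add(Tensor self, Tensor other, *, Scalar alpha=1) with alpha left at its default:
print(format_call("torch.add", ["self", "other", "alpha"],
                  [False, False, True], ["x", "y", 1], necessary_args=2))
# -> torch.add(x, y)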
