Skip to content

Commit 7fc2929

Browse files
antoniojkim authored and pytorchmergebot committed
Add support for torch.Generator type in TorchScript (pytorch#110413)
- Add support for `torch.Generator` type in TorchScript - Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_` - Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab) CC: @eellison @davidberard98 @GlebKazantaev @behzad-a Pull Request resolved: pytorch#110413 Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
1 parent b88abb1 commit 7fc2929

39 files changed

+666
-178
lines changed

aten/src/ATen/core/ivalue.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,13 @@ std::ostream& IValue::repr(
644644
c10::printQuotedString(out, device_stream.str());
645645
return out << ")";
646646
}
647+
case IValue::Tag::Generator: {
648+
auto generator = v.toGenerator();
649+
out << "torch.Generator(device=";
650+
c10::printQuotedString(out, generator.device().str());
651+
out << ", seed=" << generator.current_seed() << ")";
652+
return out;
653+
}
647654
case IValue::Tag::GenericDict:
648655
return printMaybeAnnotatedDict(out, v, formatter);
649656
case IValue::Tag::Enum: {
@@ -956,6 +963,7 @@ IValue IValue::deepcopy(
956963
case IValue::Tag::SymBool:
957964
case IValue::Tag::Bool:
958965
case IValue::Tag::Device:
966+
case IValue::Tag::Generator:
959967
case IValue::Tag::Uninitialized: {
960968
copy = *this;
961969
} break;

aten/src/ATen/core/type_factory.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace c10 {
2828
_(complex, ComplexType) \
2929
_(str, StringType) \
3030
_(Device, DeviceObjType) \
31+
_(Generator, GeneratorType) \
3132
_(Stream, StreamObjType) \
3233
_(number, NumberType) \
3334
_(None, NoneType) \

aten/src/ATen/native/ts_native_functions.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ full_codegen:
168168
- slice_scatter
169169
- diagonal_scatter
170170
- as_strided_scatter
171+
# random ops
172+
- normal_functional
173+
- uniform
171174
ir_gen:
172175
- selu
173176
supported:
@@ -177,7 +180,6 @@ supported:
177180
- empty.memory_format
178181
- empty_strided
179182
- fill_.Scalar
180-
- normal_
181183
- max_pool3d_with_indices
182184
- max_pool3d_with_indices_backward
183185
- _to_copy

build_variables.bzl

-1
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,6 @@ lazy_tensor_ts_sources = [
447447
"torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
448448
"torch/csrc/lazy/ts_backend/config.cpp",
449449
"torch/csrc/lazy/ts_backend/ops/device_data.cpp",
450-
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
451450
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
452451
"torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
453452
"torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",

docs/source/jit_unsupported.rst

-1
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,3 @@ we suggest using :meth:`torch.jit.trace`.
8888
* :class:`torch.nn.AdaptiveLogSoftmaxWithLoss`
8989
* :class:`torch.autograd.Function`
9090
* :class:`torch.autograd.enable_grad`
91-
* :class:`torch.Generator`

test/jit/test_generator.py

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Owner(s): ["oncall: jit"]
2+
3+
import io
4+
import math
5+
import unittest
6+
7+
import torch
8+
from torch.nn import init
9+
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
10+
from torch.testing._internal.jit_utils import JitTestCase
11+
12+
13+
if __name__ == "__main__":
14+
raise RuntimeError(
15+
"This test file is not meant to be run directly, use:\n\n"
16+
"\tpython test/test_jit.py TESTNAME\n\n"
17+
"instead."
18+
)
19+
20+
21+
class TestGenerator(JitTestCase):
22+
# torch.jit.trace does not properly capture the generator manual seed
23+
# and thus is non deterministic even if the generator is manually seeded
24+
@skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
25+
@unittest.expectedFailure
26+
def test_trace(self):
27+
def f():
28+
generator = torch.Generator()
29+
generator.seed()
30+
generator.manual_seed(2023)
31+
generator.initial_seed()
32+
tensor = torch.empty(2, 2)
33+
tensor.uniform_(0, 1, generator=generator)
34+
return tensor
35+
36+
traced_f = torch.jit.trace(f, ())
37+
38+
# Run this 3 times to ensure that the generator is being manually seeded
39+
# each time the traced function is run
40+
for i in range(3):
41+
torch.manual_seed(1)
42+
43+
eager_tensor = f()
44+
45+
# Change the seed of the default generator to
46+
# check that we're using the generator from the
47+
# trace
48+
torch.manual_seed(2)
49+
traced_tensor = traced_f()
50+
51+
self.assertEqual(eager_tensor, traced_tensor)
52+
53+
def test_script(self):
54+
def f():
55+
generator = torch.Generator()
56+
generator.seed()
57+
generator.manual_seed(2023)
58+
generator.initial_seed()
59+
tensor = torch.empty(2, 2)
60+
tensor.normal_(-1.0, 1.0, generator=generator)
61+
return tensor
62+
63+
script_f = torch.jit.script(f, ())
64+
65+
# Run this 3 times to ensure that the generator is being manually seeded
66+
# each time the traced function is run
67+
for i in range(3):
68+
torch.manual_seed(1)
69+
70+
eager_tensor = f()
71+
72+
# Change the seed of the default generator to
73+
# check that we're using the generator from the
74+
# trace
75+
torch.manual_seed(2)
76+
77+
script_tensor = script_f()
78+
79+
self.assertEqual(eager_tensor, script_tensor)
80+
81+
def test_default_generator(self):
82+
def f():
83+
# check that calling manual seed for the default generator works
84+
torch.manual_seed(2023)
85+
tensor = torch.empty(2, 2)
86+
tensor.normal_(-1.0, 1.0)
87+
return tensor
88+
89+
torch.manual_seed(1)
90+
91+
eager_tensor = f()
92+
93+
torch.manual_seed(2)
94+
95+
script_f = torch.jit.script(f, ())
96+
script_tensor = script_f()
97+
98+
self.assertEqual(eager_tensor, script_tensor)
99+
100+
def test_generator_arg(self):
101+
def f(generator: torch.Generator):
102+
tensor = torch.empty(2, 2)
103+
tensor.normal_(-1.0, 1.0, generator=generator)
104+
return tensor
105+
106+
generator = torch.Generator()
107+
generator.manual_seed(2023)
108+
109+
script_f = torch.jit.script(f, (generator,))
110+
111+
for i in range(3):
112+
generator = torch.Generator()
113+
generator.manual_seed(2023 + i)
114+
115+
torch.manual_seed(1 + i)
116+
117+
eager_tensor = f(generator)
118+
119+
generator = torch.Generator()
120+
generator.manual_seed(2023 + i)
121+
122+
torch.manual_seed(1 + i)
123+
124+
script_tensor = script_f(generator)
125+
126+
self.assertEqual(eager_tensor, script_tensor)
127+
128+
def test_save_load(self):
129+
class Foo(torch.nn.Module):
130+
def __init__(self):
131+
super().__init__()
132+
self.foo = torch.nn.Linear(2, 2, bias=False)
133+
self.bar = torch.nn.Linear(2, 2, bias=False)
134+
135+
self.reset_parameters()
136+
137+
def reset_linear(self, module, generator):
138+
init.kaiming_uniform_(
139+
module.weight, a=math.sqrt(5), generator=generator
140+
)
141+
142+
def reset_parameters(self):
143+
generator = torch.Generator()
144+
generator.manual_seed(1)
145+
self.reset_linear(self.foo, generator)
146+
147+
generator = torch.Generator()
148+
generator.manual_seed(2)
149+
self.reset_linear(self.bar, generator)
150+
151+
def forward(self, x):
152+
x = self.foo(x)
153+
x = self.bar(x)
154+
155+
generator = torch.Generator()
156+
generator.manual_seed(3)
157+
r = torch.empty_like(x)
158+
r.normal_(0.0, 1.0, generator=generator)
159+
160+
return x, r
161+
162+
eager_foo = Foo()
163+
164+
script_module = torch.jit.script(Foo())
165+
saved_module = io.BytesIO()
166+
torch.jit.save(script_module, saved_module)
167+
saved_module.seek(0)
168+
169+
loaded_module = torch.jit.load(saved_module)
170+
171+
self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight)
172+
self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight)
173+
174+
try:
175+
# Run this 3 times so make sure that the generator seed is being set
176+
# every time forward is called
177+
for i in range(3):
178+
x = torch.ones(2, 2)
179+
out1, r1 = eager_foo(x)
180+
out2, r2 = loaded_module(x)
181+
182+
try:
183+
self.assertEqual(out1, out2)
184+
except: # noqa: B001, E722
185+
print(f"Iteration {i}:\n{out1=}\n{out2=}")
186+
raise
187+
188+
try:
189+
self.assertEqual(r1, r2)
190+
except: # noqa: B001, E722
191+
print(f"Iteration {i}:\n{r1=}\n{r2=}")
192+
raise
193+
except: # noqa: B001, E722
194+
print(loaded_module.forward.code)
195+
raise

test/lazy/test_generator.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Owner(s): ["oncall: jit"]
2+
3+
import torch
4+
import torch._lazy.metrics as metrics
5+
import torch._lazy.ts_backend
6+
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
7+
8+
torch._lazy.ts_backend.init()
9+
10+
11+
class LazyGeneratorTest(TestCase):
12+
def test_generator(self):
13+
"""
14+
Test that generators are being inserted into the TorchScript
15+
graph by setting different seeds before each call to
16+
generate_tensor but the resulting tensor is the same
17+
"""
18+
19+
def generate_tensor():
20+
g1 = torch.Generator()
21+
g1.manual_seed(2023)
22+
t1 = torch.tensor(1.0)
23+
t1.uniform_(generator=g1)
24+
25+
g2 = torch.Generator()
26+
g2.manual_seed(2024)
27+
t2 = torch.tensor(1.0)
28+
t2.normal_(generator=g2)
29+
30+
return t1, t2
31+
32+
torch.manual_seed(1)
33+
34+
with torch.device("cpu"):
35+
cpu_t1, cpu_t2 = generate_tensor()
36+
37+
torch.manual_seed(2)
38+
39+
with torch.device("lazy"):
40+
lazy_t1, lazy_t2 = generate_tensor()
41+
42+
torch._lazy.mark_step()
43+
44+
assert torch.allclose(
45+
cpu_t1, lazy_t1.to("cpu")
46+
), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
47+
assert torch.allclose(
48+
cpu_t2, lazy_t2.to("cpu")
49+
), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
50+
51+
@skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
52+
def test_generator_causes_multiple_compiles(self):
53+
"""
54+
Test that inserting generators with different seed caused recompile
55+
"""
56+
57+
def generate_tensor(seed):
58+
t = torch.tensor(1.0)
59+
g = torch.Generator()
60+
g.manual_seed(seed)
61+
t.uniform_(-1, 1, generator=g)
62+
return t
63+
64+
metrics.reset()
65+
66+
with torch.device("lazy"):
67+
t = generate_tensor(1)
68+
torch._lazy.mark_step()
69+
70+
uncached_compile = metrics.counter_value("UncachedCompile")
71+
assert (
72+
uncached_compile == 1
73+
), f"Expected 1 uncached compiles, got {uncached_compile}"
74+
75+
t = generate_tensor(2)
76+
torch._lazy.mark_step()
77+
78+
uncached_compile = metrics.counter_value("UncachedCompile")
79+
assert (
80+
uncached_compile == 2
81+
), f"Expected 2 uncached compiles, got {uncached_compile}"
82+
83+
t = generate_tensor(1)
84+
torch._lazy.mark_step()
85+
86+
uncached_compile = metrics.counter_value("UncachedCompile")
87+
assert (
88+
uncached_compile == 2
89+
), f"Expected 2 uncached compiles, got {uncached_compile}"
90+
cached_compile = metrics.counter_value("CachedCompile")
91+
assert (
92+
cached_compile == 1
93+
), f"Expected 1 cached compile, got {cached_compile}"
94+
95+
metrics.reset()
96+
97+
latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
98+
assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
99+
assert "aten::uniform" in latest_graph
100+
101+
102+
if __name__ == "__main__":
103+
run_tests()

test/lazy/test_ts_opinfo.py

+7
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,17 @@ def assert_allclose_rec(t):
231231

232232
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
233233
for sample in samples:
234+
# Need to run mark step so that all random ops are computed in the right order
235+
torch._lazy.mark_step()
236+
234237
args = [sample.input] + list(sample.args)
235238
kwargs = sample.kwargs
236239
copy_args = clone_to_device(args, test_device)
237240

238241
r_exp = op(*copy_args, **kwargs)
239242
r_actual = op(*args, **kwargs)
240243

244+
torch._lazy.mark_step()
241245
assert_allclose_rec((r_actual, r_exp))
242246

243247
@ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,)) # noqa: B950
@@ -263,6 +267,9 @@ def assert_allclose_rec(t):
263267

264268
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
265269
for sample in samples:
270+
# Need to run mark step so that all random ops are computed in the right order
271+
torch._lazy.mark_step()
272+
266273
args = [sample.input] + list(sample.args)
267274
kwargs = sample.kwargs
268275
copy_args = clone_to_device(args, test_device)

0 commit comments

Comments (0)