Commit cceabe8

davidberard98 authored and pytorchmergebot committed
[jit] ClassType hashing: hash on compilation_unit as well (pytorch#121928)
Following up on pytorch#121874 - it turns out that in our case, we're seeing repeated class names that are from different compilation units. Our previous hash function wasn't considering the compilation unit, leading to hash collisions (and then exponential memory usage in the number of copies of this class name).

Differential Revision: [D54916455](https://our.internmc.facebook.com/intern/diff/D54916455)

Pull Request resolved: pytorch#121928
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#121874
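The failure mode described in the commit message can be illustrated with a small standalone C++ sketch (the ClassType and CompilationUnit structs and the hash_combine helper below are simplified stand-ins, not the actual JIT or c10 implementations): a hash that looks only at the qualified name sends every copy of MyModuleCUTest to the same bucket, while a hash that also folds in the owning compilation unit's identity keeps them distinct.

#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Simplified stand-ins for the JIT types involved (illustrative only).
struct CompilationUnit {};

struct ClassType {
  std::string qualified_name;
  std::shared_ptr<CompilationUnit> cu; // owning compilation unit
};

// Boost-style hash mixer, similar in spirit to what a variadic helper like
// c10::get_hash does with multiple arguments (the exact mixing differs).
inline std::size_t hash_combine(std::size_t seed, std::size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Old behavior: hash only on the class name.
std::size_t hash_name_only(const ClassType& t) {
  return std::hash<std::string>{}(t.qualified_name);
}

// New behavior: also fold in the compilation unit's identity.
std::size_t hash_name_and_cu(const ClassType& t) {
  return hash_combine(
      std::hash<std::string>{}(t.qualified_name),
      std::hash<const CompilationUnit*>{}(t.cu.get()));
}

int main() {
  // 40 copies of the "same" class, each owned by its own compilation unit,
  // mirroring the save/load loop in the test below.
  std::vector<ClassType> types;
  for (int i = 0; i < 40; ++i) {
    types.push_back(
        {"__torch__.MyModuleCUTest", std::make_shared<CompilationUnit>()});
  }

  std::unordered_set<std::size_t> name_only_hashes;
  std::unordered_set<std::size_t> name_and_cu_hashes;
  for (const auto& t : types) {
    name_only_hashes.insert(hash_name_only(t));
    name_and_cu_hashes.insert(hash_name_and_cu(t));
  }

  // Name-only hashing collapses all 40 types onto a single hash value;
  // including the compilation unit should yield 40 distinct values.
  std::cout << "distinct hashes (name only):   " << name_only_hashes.size() << "\n";
  std::cout << "distinct hashes (name + unit): " << name_and_cu_hashes.size() << "\n";
}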
1 parent 2d9cee2 · commit cceabe8

2 files changed: +54 -1 lines

test/jit/test_alias_analysis.py (+52)
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: jit"]
 
+from torch.testing._internal.common_utils import TemporaryFileName
 from torch.testing._internal.jit_utils import JitTestCase
 from torch._C import parse_ir
 import torch
@@ -91,3 +92,54 @@ def foo2(self, x, y):
         inps = list(node.inputs())
         self.assertTrue(alias_db.has_writers(inps[1]))
         self.assertFalse(alias_db.has_writers(inps[2]))
+
+    def test_multiple_compilation_units(self):
+        # This is a repro of an internal issue we saw.
+        # Here, we have a large number (40) of modules each with the same name (MyModuleCUTest).
+        # AliasDB uses some hash tables that hash on types; each of these 40 modules are not
+        # identical because they have different compilation units, but they have the same name.
+        # Therefore, if we hash only on the module name (which we previously did), we will have
+        # hash collisions for all of these module types.
+        #
+        # flat_hash_map has very bad performance (exponential) for this hash collision behavior.
+        # This OOMs prior to the fix.
+        N = 40
+
+        class MultiTmpFile:
+            def __init__(self, N):
+                self.N = N
+                self.ctxs = [TemporaryFileName(mode="w", suffix=".py") for _ in range(N)]
+
+            def __enter__(self):
+                return [x.__enter__() for x in self.ctxs]
+
+            def __exit__(self, exc_type, exc_value, traceback):
+                return [x.__exit__(exc_type, exc_value, traceback) for x in self.ctxs]
+
+        class ModuleWrapper(torch.nn.Module):
+            def __init__(self, module_list):
+                super().__init__()
+                self.module_list = module_list
+
+            def forward(self, x):
+                for mod in self.module_list:
+                    x = mod(x)
+                return x
+
+        with MultiTmpFile(N) as fnames:
+            module_list = torch.nn.ModuleList()
+            global MyModuleCUTest
+
+            class MyModuleCUTest(torch.nn.Module):
+                def forward(self, x):
+                    return x + 2
+
+            for _, fname in enumerate(fnames):
+                mod = torch.jit.script(MyModuleCUTest())
+                torch.jit.save(mod, fname)
+                loaded_mod = torch.jit.load(fname)
+                module_list.append(loaded_mod)
+
+        mod = ModuleWrapper(module_list)
+        mod = torch.jit.script(mod)
+        mod(torch.zeros((2, 2)))

torch/csrc/jit/ir/type_hashing.cpp (+2 -1)
@@ -11,7 +11,8 @@ namespace torch::jit {
 namespace {
 size_t hashType(const Type& type) {
   if (auto named_type = type.castRaw<ClassType>()) {
-    return get_hash(named_type->name().value());
+    return c10::get_hash(
+        named_type->name().value(), named_type->compilation_unit());
   }
   size_t hash = 0;
   for (const auto& containedType : type.containedTypes()) {
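For reference, c10::get_hash (the helper the patched hashType() calls) is a variadic function that folds all of its arguments into a single hash value, which is what allows the compilation unit to be mixed in alongside the qualified name in one call. A minimal usage sketch, assuming a PyTorch C++ build where c10/util/hash.h is on the include path; the shared_ptr<int> owners are placeholders standing in for distinct compilation units:

#include <c10/util/hash.h>

#include <iostream>
#include <memory>
#include <string>

int main() {
  // Two "owners" with distinct pointer identities stand in for two
  // different compilation units that both define the same class name.
  std::string name = "__torch__.MyModuleCUTest";
  auto owner_a = std::make_shared<int>(0);
  auto owner_b = std::make_shared<int>(0);

  // Hashing the name alone cannot distinguish the two owners...
  std::cout << c10::get_hash(name) << "\n";

  // ...but folding the owner into the same call (as hashType() now does
  // with the compilation unit) should produce two different values.
  std::cout << c10::get_hash(name, owner_a) << "\n";
  std::cout << c10::get_hash(name, owner_b) << "\n";
  return 0;
}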

0 comments