@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: jit"]
 
+from torch.testing._internal.common_utils import TemporaryFileName
 from torch.testing._internal.jit_utils import JitTestCase
 from torch._C import parse_ir
 import torch
@@ -91,3 +92,54 @@ def foo2(self, x, y):
         inps = list(node.inputs())
         self.assertTrue(alias_db.has_writers(inps[1]))
         self.assertFalse(alias_db.has_writers(inps[2]))
+
+    def test_multiple_compilation_units(self):
+        # This is a repro of an internal issue we saw.
+        # Here we have a large number (40) of modules, each with the same name (MyModuleCUTest).
+        # AliasDB uses some hash tables that hash on types; these 40 module types are not
+        # identical, because they come from different compilation units, but they share a name.
+        # Therefore, if we hash only on the module name (which we previously did), all of these
+        # module types collide.
+        #
+        # flat_hash_map has very bad (exponential) performance for this hash collision behavior.
+        # This OOMs prior to the fix.
+        N = 40
+
+        class MultiTmpFile:
+            def __init__(self, N):
+                self.N = N
+                self.ctxs = [TemporaryFileName(mode="w", suffix=".py") for _ in range(N)]
+
+            def __enter__(self):
+                return [x.__enter__() for x in self.ctxs]
+
+            def __exit__(self, exc_type, exc_value, traceback):
+                return [x.__exit__(exc_type, exc_value, traceback) for x in self.ctxs]
+
+        class ModuleWrapper(torch.nn.Module):
+            def __init__(self, module_list):
+                super().__init__()
+                self.module_list = module_list
+
+            def forward(self, x):
+                for mod in self.module_list:
+                    x = mod(x)
+                return x
+
+        with MultiTmpFile(N) as fnames:
+            module_list = torch.nn.ModuleList()
+            global MyModuleCUTest
+
+            class MyModuleCUTest(torch.nn.Module):
+                def forward(self, x):
+                    return x + 2
+
+            for fname in fnames:
+                mod = torch.jit.script(MyModuleCUTest())
+                torch.jit.save(mod, fname)
+                loaded_mod = torch.jit.load(fname)
+                module_list.append(loaded_mod)
+
+            mod = ModuleWrapper(module_list)
+            mod = torch.jit.script(mod)
+            mod(torch.zeros((2, 2)))
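The comment in the new test describes the underlying failure mode: if a type's hash is computed from its name alone, the 40 distinct module types that all share the name MyModuleCUTest land in the same hash bucket, and flat_hash_map degrades badly on that many collisions. The following is a minimal sketch in plain Python of that collision behavior, using a hypothetical NamedType stand-in; it is not PyTorch's actual type or AliasDB code.

# Hypothetical stand-in for a named class type; not the real TorchScript implementation.
class NamedType:
    def __init__(self, name, compilation_unit):
        self.name = name
        self.compilation_unit = compilation_unit

    def __eq__(self, other):
        # Types from different compilation units are distinct even when the names match.
        return (self.name, self.compilation_unit) == (other.name, other.compilation_unit)

    def __hash__(self):
        # Name-only hash: every MyModuleCUTest type collides in one bucket.
        # Mixing in the compilation unit would keep the hashes distinct, e.g.
        #   return hash((self.name, self.compilation_unit))
        return hash(self.name)

table = {NamedType("MyModuleCUTest", cu): cu for cu in range(40)}
# All 40 keys are distinct, but they hash identically, so each insert and lookup
# must probe through every previously inserted colliding key.
assert len(table) == 40
assert len({hash(k) for k in table}) == 1

The test's comment implies the fix is to hash on more than just the name, so module types from different compilation units no longer collide.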