Restore TorchScript functionality (necessary for quantization) (#129)

* Add TorchScript and deepcopy tests.
* TorchScript fix: `ModuleList` can only be indexed with literals.
* TorchScript fix: `**kwargs` is not allowed.
* TorchScript fix: we cannot condition on global state. This removes our custom
  workaround for macOS 13.2, which is fixed in 13.3.
* TorchScript fix: TorchScript does not allow `Module` type annotations.
* TorchScript fixes: many fixes for the `Attention` class.
  - Ensure TorchScript type inference works.
  - We can't reference global variables, including errors.
  - Dataclasses do not work well.
  - We need an `__init__` that can be found in source (not synthesized).
  - The tuple type only works fully specified (not `Tuple` or `Tuple[int, ...]`).
* Revert "Add support for Torch `scaled_dot_product_attention` (#128)". This
  reverts commit 68a355a. The functionality introduced in that PR used global
  state to detect whether `scaled_dot_product_attention` is available and to
  check whether the user wants to use it. However, we cannot rely on global
  state in TorchScript.
* Attempt to fix CI pip issues.
* Describe some TorchScript rules of thumb in DEVELOP.md.
* Simplify TorchScript type inference.
* Remove unused imports.
Showing 12 changed files with 192 additions and 76 deletions.
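One of the fixes above, "ModuleList can only be indexed with literals", can be illustrated with a minimal sketch (the `Stack` module below is a toy example, not the project's actual code): TorchScript supports iterating over a `ModuleList`, or indexing it with a literal, but not indexing it with a runtime variable.

```python
import torch
from torch import Tensor, nn


class Stack(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

    def forward(self, X: Tensor) -> Tensor:
        # OK: TorchScript supports iterating over a ModuleList.
        for layer in self.layers:
            X = layer(X)
        # Not OK (would fail to script): self.layers[i] for a variable i.
        return X


scripted = torch.jit.script(Stack())  # scripts without error
```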
`DEVELOP.md` (new file):
# Development

## TorchScript
Every `torch.nn.Module` in this project must have a TorchScript conversion test.
TorchScript only supports a subset of Python, and we want to make sure that all
models are convertible to TorchScript.

In this section, we give some rules of thumb to avoid conversion errors.
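Such a conversion test can be as simple as scripting the module and checking that the scripted and eager outputs agree. A minimal sketch (the toy `Foo` module and the test name are illustrative, not the project's actual tests):

```python
import torch
from torch import Tensor, nn


class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        return X * 2.0


def test_foo_torchscript():
    model = Foo()
    # Fails with a torch.jit.Error if the module is not convertible.
    scripted = torch.jit.script(model)
    X = torch.rand(2, 3)
    torch.testing.assert_close(scripted(X), model(X))
```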
### Do not use global state

TorchScript cannot use global state. One form of global state that we have
in this project is the `Errors` class. Consequently, we cannot use `Errors`
in `Module`s. The following is therefore invalid:
```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        # Problem: Errors fields are global state.
        raise ValueError(Errors.E042)
```
In these cases we have to use an inline string instead:
```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        raise ValueError("This module does not do anything yet.")
```
For the same reason, we also cannot rely on `has_*` booleans in a module:
```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        # Problem: conditional on global state.
        if has_torch_feature:
            ...
```
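One pattern that does script (a sketch, assuming the flag only needs to be read once; the `HAS_TORCH_FEATURE` name is hypothetical) is to capture the global value as an instance attribute in `__init__`. TorchScript then sees a plain `bool` attribute rather than global state. Whether this is appropriate depends on the design; the commit above chose to revert the feature instead.

```python
import torch
from torch import Tensor, nn

# Hypothetical feature check; stands in for any global flag.
HAS_TORCH_FEATURE = hasattr(torch, "compile")


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Read the global once, at construction time.
        self.has_torch_feature = HAS_TORCH_FEATURE

    def forward(self, X: Tensor) -> Tensor:
        if self.has_torch_feature:
            return X * 2.0
        return X + 1.0


scripted = torch.jit.script(Foo())
```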
### Typing limitations

TorchScript only supports a small [subset of Python types](https://pytorch.org/docs/stable/jit_language_reference.html#supported-type).
This also applies to type annotations. For instance, the following will not
work, because TorchScript only supports fully-specified tuple types:
```python
class Foo(nn.Module):
    # Problem: underspecified tuple
    def shape(self) -> Tuple:
        ...

    # Problem: underspecified tuple
    def shape(self) -> Tuple[int, ...]:
        ...
```
The following is OK, because it is a valid TorchScript type:
```python
class Foo(nn.Module):
    def shape(self) -> Tuple[int, int]:
        ...
```
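As a quick check (a toy sketch, not project code), a module returning a fully specified `Tuple[int, int]` scripts cleanly:

```python
from typing import Tuple

import torch
from torch import Tensor, nn


class Shape2D(nn.Module):
    def forward(self, X: Tensor) -> Tuple[int, int]:
        # Fully specified tuple type: TorchScript can infer both elements.
        return X.shape[0], X.shape[1]


scripted = torch.jit.script(Shape2D())
print(scripted(torch.rand(2, 3)))  # (2, 3)
```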
### Do not use `**kwargs` arguments

TorchScript does not support `**kwargs` wildcards, so the following is
invalid:
```python
class Foo(nn.Module):
    ...

    def forward(self, X: Tensor, **kwargs) -> Tensor:
        hidden = self.inner1(X)
        return self.inner2(hidden, **kwargs)
```
Instead, we have to spell out all arguments, e.g.:
```python
class Foo(nn.Module):
    ...

    def forward(self, X: Tensor, attn_mask: AttentionMask) -> Tensor:
        hidden = self.inner1(X)
        return self.inner2(hidden, attn_mask=attn_mask)
```
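A runnable sketch of the spelled-out version (with a stand-in `Inner` module and a `float` argument instead of the project's `AttentionMask`):

```python
import torch
from torch import Tensor, nn


class Inner(nn.Module):
    def forward(self, X: Tensor, scale: float) -> Tensor:
        return X * scale


class Outer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()

    # All arguments are spelled out, so TorchScript can type-check the call.
    def forward(self, X: Tensor, scale: float) -> Tensor:
        return self.inner(X, scale=scale)


scripted = torch.jit.script(Outer())
```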