-
Notifications
You must be signed in to change notification settings - Fork 662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST parameter fixup #897
Open
andersgb
wants to merge
3
commits into
timeseriesAI:main
Choose a base branch
from
andersgb:tst-parameter-fixup
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
TST parameter fixup #897
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Fix documentation on default value of d_k and d_v. If they are not provided, and d_model and n_heads are kept at default values, they will be set to 128//16 which is 8. Also, specify the usual values range to 8-64 which should more or less correspond to a d_model/n_heads range.
Follow the 050_models.TSTPlus.ipynb implementation and make sure we actually assert what we say we assert on this line. This commit breaks the inline test where d_model is 128 and n_heads is 3, see below. ``` AssertionError in /home/anders/dev/ml/tsai/nbs/049_models.TST.ipynb: =========================================================================== While Executing Cell timeseriesAI#13: --------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[1], line 2 1 t = torch.rand(16, 50, 128) ----> 2 output = _TSTEncoderLayer(q_len=50, d_model=128, n_heads=3, d_k=None, d_v=None, d_ff=512, dropout=0.1, activation='gelu')(t) 3 output.shape File ~/anaconda3/envs/tsai_dev/lib/python3.9/site-packages/fastcore/meta.py:40, in PrePostInitMeta.__call__(cls, *args, **kwargs) 38 if type(res)==cls: 39 if hasattr(res,'__pre_init__'): res.__pre_init__(*args,**kwargs) ---> 40 res.__init__(*args,**kwargs) 41 if hasattr(res,'__post_init__'): res.__post_init__(*args,**kwargs) 42 return res Cell In[1], line 11, in _TSTEncoderLayer.__init__(self, q_len, d_model, n_heads, d_k, d_v, d_ff, dropout, activation) 8 def __init__(self, q_len:int, d_model:int, n_heads:int, d_k:Optional[int]=None, d_v:Optional[int]=None, d_ff:int=256, dropout:float=0.1, 9 activation:str="gelu"): ---> 11 assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})" 12 d_k = ifnone(d_k, d_model // n_heads) 13 d_v = ifnone(d_v, d_model // n_heads) AssertionError: d_model (128) must be divisible by n_heads (3) ```
I believe this assertion is unnecessary because the dimensions will actually work out even if d_model is not divisible by n_heads. See a model printout of d_model=128, n_heads=3, d_k=11, d_v=9 below. Not saying a parameter change like this is a good idea, but it will happily run and learn on a sample dataset I'm using at least. The in_features and out_features of the entire MHA block will be d_model in any case. Comparing with torch.nn.MultiheadAttention [1] (which is used in the original paper implementation [2]), I think our `d_k*n_heads` corresponds to the `kdim` optional parameter. And, similarly, our `d_v*n_heads` corresponds to the `vdim` parameter. [1] https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html [2] https://github.com/gzerveas/mvts_transformer/blob/3f2e378bc77d02e82a44671f20cf15bc7761671a/src/models/ts_transformer.py#L152 ``` In [2]: clf.model Out[2]: TST( (W_P): Linear(in_features=600, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) (encoder): _TSTEncoder( (layers): ModuleList( (0-2): 3 x _TSTEncoderLayer( (self_attn): _MultiHeadAttention( (W_Q): Linear(in_features=128, out_features=33, bias=False) (W_K): Linear(in_features=128, out_features=33, bias=False) (W_V): Linear(in_features=128, out_features=27, bias=False) (W_O): Linear(in_features=27, out_features=128, bias=False) ) (dropout_attn): Dropout(p=0.1, inplace=False) (batchnorm_attn): Sequential( (0): Transpose(1, 2) (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): Transpose(1, 2) ) (ff): Sequential( (0): Linear(in_features=128, out_features=256, bias=True) (1): GELU(approximate='none') (2): Dropout(p=0.1, inplace=False) (3): Linear(in_features=256, out_features=128, bias=True) ) (dropout_ffn): Dropout(p=0.1, inplace=False) (batchnorm_ffn): Sequential( (0): Transpose(1, 2) (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): Transpose(1, 2) ) ) ) ) (flatten): fastai.layers.Flatten(full=False) (head): Sequential( (0): GELU(approximate='none') (1): fastai.layers.Flatten(full=False) (2): Linear(in_features=2304, out_features=2, bias=True) ) ) ```
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
First commit should be trivially correct (docstring only)
Second commit fixes a bug, which was my original intention for opening this PR
Third commit reintroduces the original behavior because I don't think the assertion is necessary. I tried to understand and compare to how the pytorch implementation of multi head attention handles these dimensions.
So the net change of this PR is only docstring changes. See commit messages for further details.