Skip to content

Commit

Permalink
add option to set frame padding for 3D CCT (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
kalekundert authored Jan 4, 2025
1 parent e7cba9b commit b7ed6ba
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions vit_pytorch/cct_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ def __init__(
stride,
padding,
frame_stride=1,
frame_padding=None,
frame_pooling_stride=1,
frame_pooling_kernel_size=1,
frame_pooling_padding=None,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
Expand All @@ -188,16 +190,22 @@ def __init__(

n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

if frame_padding is None:
frame_padding = frame_kernel_size // 2

if frame_pooling_padding is None:
frame_pooling_padding = frame_pooling_kernel_size // 2

self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv3d(chan_in, chan_out,
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
stride=(frame_stride, stride, stride),
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
padding=(frame_padding, padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
Expand Down Expand Up @@ -324,8 +332,10 @@ def __init__(
n_conv_layers=1,
frame_stride=1,
frame_kernel_size=3,
frame_padding=None,
frame_pooling_kernel_size=1,
frame_pooling_stride=1,
frame_pooling_padding=None,
kernel_size=7,
stride=2,
padding=3,
Expand All @@ -342,8 +352,10 @@ def __init__(
n_output_channels=embedding_dim,
frame_stride=frame_stride,
frame_kernel_size=frame_kernel_size,
frame_padding=frame_padding,
frame_pooling_stride=frame_pooling_stride,
frame_pooling_kernel_size=frame_pooling_kernel_size,
frame_pooling_padding=frame_pooling_padding,
kernel_size=kernel_size,
stride=stride,
padding=padding,
Expand Down

0 comments on commit b7ed6ba

Please sign in to comment.