diff --git a/pytorch_tcn/__init__.py b/pytorch_tcn/__init__.py index 702d04c..bba73dd 100644 --- a/pytorch_tcn/__init__.py +++ b/pytorch_tcn/__init__.py @@ -2,4 +2,4 @@ from pytorch_tcn.conv import TemporalConv1d from pytorch_tcn.conv import TemporalConvTranspose1d -__version__ = '1.2.2.dev0' \ No newline at end of file +__version__ = '1.2.2.dev1' \ No newline at end of file diff --git a/pytorch_tcn/buffer.py b/pytorch_tcn/buffer.py index bb26126..c74d0d2 100644 --- a/pytorch_tcn/buffer.py +++ b/pytorch_tcn/buffer.py @@ -20,6 +20,7 @@ def __init__( self.in_buffers = None self.out_buffers = [] + self.internal_buffers = [] return def __iter__(self): @@ -37,6 +38,13 @@ def append_out_buffer( ): self.out_buffers.append(x) return + + def append_internal_buffer( + self, + x: torch.Tensor, + ): + self.internal_buffers.append(x) + return def next_in_buffer( self, @@ -44,6 +52,10 @@ def next_in_buffer( return self.__next__() def step(self): + # If in_buffers is None, then the internal buffers are used as input + # After the first step, the operation will continue as usual + if self.in_buffers is None: + self.in_buffers_length = len( self.internal_buffers) if len( self.out_buffers ) != self.in_buffers_length: raise ValueError( """ diff --git a/pytorch_tcn/pad.py b/pytorch_tcn/pad.py index 605f6f0..e673768 100644 --- a/pytorch_tcn/pad.py +++ b/pytorch_tcn/pad.py @@ -78,16 +78,28 @@ def __init__( # Buffer is used for streaming inference if buffer is None: - buffer = torch.zeros( - 1, - in_channels, - self.pad_len, - ) + if in_channels is None: + buffer = torch.zeros( + 1, + self.pad_len, + ) + else: + buffer = torch.zeros( + 1, + in_channels, + self.pad_len, + ) elif isinstance(buffer, (int, float)): - buffer = torch.full( - size = (1, in_channels, self.pad_len), - fill_value = buffer, - ) + if in_channels is None: + buffer = torch.full( + size = (1, self.pad_len), + fill_value = buffer, + ) + else: + buffer = torch.full( + size = (1, in_channels, self.pad_len), + fill_value = buffer, + ) elif not isinstance(buffer, torch.Tensor): raise ValueError( f""" @@ -129,13 +141,16 @@ def pad_inference( in_buffer = self.buffer else: in_buffer = buffer_io.next_in_buffer() + if in_buffer is None: + in_buffer = self.buffer + buffer_io.append_internal_buffer( in_buffer ) x = torch.cat( (in_buffer, x), -1, ) - out_buffer = x[:, :, -self.pad_len: ] + out_buffer = x[ ..., -self.pad_len: ] if buffer_io is None: self.buffer = out_buffer else: @@ -157,7 +172,7 @@ def forward( def reset_buffer(self): self.buffer.zero_() - if self.buffer.shape[2] != self.pad_len: + if self.buffer.shape[-1] != self.pad_len: raise ValueError( f""" Buffer shape {self.buffer.shape} does not match the expected diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py index ec7fda8..bc75e3b 100644 --- a/pytorch_tcn/tcn.py +++ b/pytorch_tcn/tcn.py @@ -23,6 +23,8 @@ from typing import Optional from collections.abc import Iterable from pytorch_tcn.conv import TemporalConv1d, TemporalConvTranspose1d +from pytorch_tcn.pad import TemporalPad1d +from pytorch_tcn.buffer import BufferIO activation_fn = dict( @@ -171,19 +173,53 @@ def _init_weights(m): def reset_buffers(self): def _reset_buffer(x): - if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): + if isinstance(x, (TemporalPad1d,) ): x.reset_buffer() self.apply(_reset_buffer) return def get_buffers(self): + """ + Get all buffers of the network in the order they were created. + """ buffers = [] def _get_buffers(x): - if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): + if isinstance(x, (TemporalPad1d,) ): buffers.append(x.buffer) self.apply(_get_buffers) return buffers + def get_in_buffers(self, *args, **kwargs): + """ + Get all buffers of the network in the order they are used in + the forward pass. This is important for external buffer io, e.g. + with ONNX inference. + """ + # Get the internal buffer state + buffers = self.get_buffers() + # Get the in_buffers via dummy forward pass + buffer_io = BufferIO( in_buffers=None ) + self( + *args, + inference=True, + buffer_io=buffer_io, + **kwargs, + ) + in_buffers = buffer_io.internal_buffers + # Restore the internal buffer state + self.set_buffers( buffers ) + return in_buffers + + def set_buffers(self, buffers): + """ + Set all buffers of the network in the order they were created. + """ + def _set_buffers(x): + if isinstance(x, (TemporalPad1d,) ): + x.buffer = buffers.pop(0) + self.apply(_set_buffers) + return + def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2)