From 5401140a021ad3029dc9739559d28477cdc54d8f Mon Sep 17 00:00:00 2001 From: paul-krug Date: Thu, 7 Nov 2024 11:29:40 +0100 Subject: [PATCH 1/3] extend bufferio --- pytorch_tcn/__init__.py | 2 +- pytorch_tcn/buffer.py | 12 ++++++++++++ pytorch_tcn/pad.py | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pytorch_tcn/__init__.py b/pytorch_tcn/__init__.py index 702d04c..bba73dd 100644 --- a/pytorch_tcn/__init__.py +++ b/pytorch_tcn/__init__.py @@ -2,4 +2,4 @@ from pytorch_tcn.conv import TemporalConv1d from pytorch_tcn.conv import TemporalConvTranspose1d -__version__ = '1.2.2.dev0' \ No newline at end of file +__version__ = '1.2.2.dev1' \ No newline at end of file diff --git a/pytorch_tcn/buffer.py b/pytorch_tcn/buffer.py index bb26126..c74d0d2 100644 --- a/pytorch_tcn/buffer.py +++ b/pytorch_tcn/buffer.py @@ -20,6 +20,7 @@ def __init__( self.in_buffers = None self.out_buffers = [] + self.internal_buffers = [] return def __iter__(self): @@ -37,6 +38,13 @@ def append_out_buffer( ): self.out_buffers.append(x) return + + def append_internal_buffer( + self, + x: torch.Tensor, + ): + self.internal_buffers.append(x) + return def next_in_buffer( self, @@ -44,6 +52,10 @@ def next_in_buffer( return self.__next__() def step(self): + # If in_buffers is None, then the internal buffers are used as input + # After the first step, the operation will continue as usual + if self.in_buffers is None: + self.in_buffers_length = len( self.internal_buffers) if len( self.out_buffers ) != self.in_buffers_length: raise ValueError( """ diff --git a/pytorch_tcn/pad.py b/pytorch_tcn/pad.py index 605f6f0..42dbf3b 100644 --- a/pytorch_tcn/pad.py +++ b/pytorch_tcn/pad.py @@ -129,6 +129,9 @@ def pad_inference( in_buffer = self.buffer else: in_buffer = buffer_io.next_in_buffer() + if in_buffer is None: + in_buffer = self.buffer + buffer_io.append_internal_buffer( in_buffer ) x = torch.cat( (in_buffer, x), From db1ad36d1d58a9f7c4781b318bb51b8cf4409d98 Mon Sep 17 00:00:00 2001 From: paul-krug Date: Sat, 9 Nov 2024 12:21:39 +0100 Subject: [PATCH 2/3] update pad and buffer-io --- pytorch_tcn/pad.py | 34 +++++++++++++++++++++++----------- pytorch_tcn/tcn.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/pytorch_tcn/pad.py b/pytorch_tcn/pad.py index 42dbf3b..e673768 100644 --- a/pytorch_tcn/pad.py +++ b/pytorch_tcn/pad.py @@ -78,16 +78,28 @@ def __init__( # Buffer is used for streaming inference if buffer is None: - buffer = torch.zeros( - 1, - in_channels, - self.pad_len, - ) + if in_channels is None: + buffer = torch.zeros( + 1, + self.pad_len, + ) + else: + buffer = torch.zeros( + 1, + in_channels, + self.pad_len, + ) elif isinstance(buffer, (int, float)): - buffer = torch.full( - size = (1, in_channels, self.pad_len), - fill_value = buffer, - ) + if in_channels is None: + buffer = torch.full( + size = (1, self.pad_len), + fill_value = buffer, + ) + else: + buffer = torch.full( + size = (1, in_channels, self.pad_len), + fill_value = buffer, + ) elif not isinstance(buffer, torch.Tensor): raise ValueError( f""" @@ -138,7 +150,7 @@ def pad_inference( -1, ) - out_buffer = x[:, :, -self.pad_len: ] + out_buffer = x[ ..., -self.pad_len: ] if buffer_io is None: self.buffer = out_buffer else: @@ -160,7 +172,7 @@ def forward( def reset_buffer(self): self.buffer.zero_() - if self.buffer.shape[2] != self.pad_len: + if self.buffer.shape[-1] != self.pad_len: raise ValueError( f""" Buffer shape {self.buffer.shape} does not match the expected diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py index ec7fda8..6cba877 100644 --- a/pytorch_tcn/tcn.py +++ b/pytorch_tcn/tcn.py @@ -23,6 +23,7 @@ from typing import Optional from collections.abc import Iterable from pytorch_tcn.conv import TemporalConv1d, TemporalConvTranspose1d +from pytorch_tcn.buffer import BufferIO activation_fn = dict( @@ -177,6 +178,9 @@ def _reset_buffer(x): return def get_buffers(self): + """ + Get all buffers of the network in the order they were created. + """ buffers = [] def _get_buffers(x): if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): @@ -184,6 +188,37 @@ def _get_buffers(x): self.apply(_get_buffers) return buffers + def get_in_buffers(self, *args, **kwargs): + """ + Get all buffers of the network in the order they are used in + the forward pass. This is important for external buffer io, e.g. + with ONNX inference. + """ + # Get the internal buffer state + buffers = self.get_buffers() + # Get the in_buffers via dummy forward pass + buffer_io = BufferIO( in_buffers=None ) + self( + *args, + inference=True, + buffer_io=buffer_io, + **kwargs, + ) + in_buffers = buffer_io.internal_buffers + # Restore the internal buffer state + self.set_buffers( buffers ) + return in_buffers + + def set_buffers(self, buffers): + """ + Set all buffers of the network in the order they were created. + """ + def _set_buffers(x): + if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): + x.buffer = buffers.pop(0) + self.apply(_set_buffers) + return + def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) From 6f4a17ff7a76417246bb5e19510ec5e11e7ab9f4 Mon Sep 17 00:00:00 2001 From: paul-krug Date: Sat, 9 Nov 2024 12:25:31 +0100 Subject: [PATCH 3/3] enhance BaseTCN functions --- pytorch_tcn/tcn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py index 6cba877..bc75e3b 100644 --- a/pytorch_tcn/tcn.py +++ b/pytorch_tcn/tcn.py @@ -23,6 +23,7 @@ from typing import Optional from collections.abc import Iterable from pytorch_tcn.conv import TemporalConv1d, TemporalConvTranspose1d +from pytorch_tcn.pad import TemporalPad1d from pytorch_tcn.buffer import BufferIO @@ -172,7 +173,7 @@ def _init_weights(m): def reset_buffers(self): def _reset_buffer(x): - if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): + if isinstance(x, (TemporalPad1d,) ): x.reset_buffer() self.apply(_reset_buffer) return @@ -183,7 +184,7 @@ def get_buffers(self): """ buffers = [] def _get_buffers(x): - if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): + if isinstance(x, (TemporalPad1d,) ): buffers.append(x.buffer) self.apply(_get_buffers) return buffers @@ -214,7 +215,7 @@ def set_buffers(self, buffers): Set all buffers of the network in the order they were created. """ def _set_buffers(x): - if isinstance(x, (TemporalConv1d, TemporalConvTranspose1d) ): + if isinstance(x, (TemporalPad1d,) ): x.buffer = buffers.pop(0) self.apply(_set_buffers) return