Skip to content

Commit

Permalink
Merge pull request #22 from paul-krug/extend-bufferio-functionality
Browse files Browse the repository at this point in the history
Extend bufferio functionality
  • Loading branch information
paul-krug authored Nov 9, 2024
2 parents bb8a04f + 6f4a17f commit 3b3deac
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pytorch_tcn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from pytorch_tcn.conv import TemporalConv1d
from pytorch_tcn.conv import TemporalConvTranspose1d

__version__ = '1.2.2.dev0'
__version__ = '1.2.2.dev1'
12 changes: 12 additions & 0 deletions pytorch_tcn/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
self.in_buffers = None

self.out_buffers = []
self.internal_buffers = []
return

def __iter__(self):
Expand All @@ -37,13 +38,24 @@ def append_out_buffer(
):
self.out_buffers.append(x)
return

def append_internal_buffer(
    self,
    x: torch.Tensor,
):
    """Record *x* on the internal-buffer list, preserving call order."""
    self.internal_buffers.append(x)

def next_in_buffer(
    self,
):
    """Advance the iterator protocol and return the next input buffer."""
    return next(self)

def step(self):
# If in_buffers is None, then the internal buffers are used as input
# After the first step, the operation will continue as usual
if self.in_buffers is None:
self.in_buffers_length = len( self.internal_buffers)
if len( self.out_buffers ) != self.in_buffers_length:
raise ValueError(
"""
Expand Down
37 changes: 26 additions & 11 deletions pytorch_tcn/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,28 @@ def __init__(

# Buffer is used for streaming inference
if buffer is None:
buffer = torch.zeros(
1,
in_channels,
self.pad_len,
)
if in_channels is None:
buffer = torch.zeros(
1,
self.pad_len,
)
else:
buffer = torch.zeros(
1,
in_channels,
self.pad_len,
)
elif isinstance(buffer, (int, float)):
buffer = torch.full(
size = (1, in_channels, self.pad_len),
fill_value = buffer,
)
if in_channels is None:
buffer = torch.full(
size = (1, self.pad_len),
fill_value = buffer,
)
else:
buffer = torch.full(
size = (1, in_channels, self.pad_len),
fill_value = buffer,
)
elif not isinstance(buffer, torch.Tensor):
raise ValueError(
f"""
Expand Down Expand Up @@ -129,13 +141,16 @@ def pad_inference(
in_buffer = self.buffer
else:
in_buffer = buffer_io.next_in_buffer()
if in_buffer is None:
in_buffer = self.buffer
buffer_io.append_internal_buffer( in_buffer )

x = torch.cat(
(in_buffer, x),
-1,
)

out_buffer = x[:, :, -self.pad_len: ]
out_buffer = x[ ..., -self.pad_len: ]
if buffer_io is None:
self.buffer = out_buffer
else:
Expand All @@ -157,7 +172,7 @@ def forward(

def reset_buffer(self):
self.buffer.zero_()
if self.buffer.shape[2] != self.pad_len:
if self.buffer.shape[-1] != self.pad_len:
raise ValueError(
f"""
Buffer shape {self.buffer.shape} does not match the expected
Expand Down
40 changes: 38 additions & 2 deletions pytorch_tcn/tcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from typing import Optional
from collections.abc import Iterable
from pytorch_tcn.conv import TemporalConv1d, TemporalConvTranspose1d
from pytorch_tcn.pad import TemporalPad1d
from pytorch_tcn.buffer import BufferIO


activation_fn = dict(
Expand Down Expand Up @@ -171,19 +173,53 @@ def _init_weights(m):

def reset_buffers(self):
    """
    Zero the streaming-inference buffer of every padding layer.

    Visits all submodules via ``nn.Module.apply`` and calls
    ``reset_buffer()`` on each ``TemporalPad1d`` (the modules that own
    the streaming buffers after this refactor; previously the check was
    on the conv wrappers).
    """
    def _reset_buffer(module):
        if isinstance(module, TemporalPad1d):
            module.reset_buffer()
    self.apply(_reset_buffer)
    return

def get_buffers(self):
    """
    Get all buffers of the network in the order they were created.

    Returns:
        list: the ``buffer`` tensor of every ``TemporalPad1d``
        submodule, collected in ``nn.Module.apply`` (creation) order.
        Note these are references, not copies.
    """
    buffers = []
    def _collect_buffer(module):
        if isinstance(module, TemporalPad1d):
            buffers.append(module.buffer)
    self.apply(_collect_buffer)
    return buffers

def get_in_buffers(self, *args, **kwargs):
    """
    Get all buffers of the network in the order they are used in
    the forward pass. This is important for external buffer io, e.g.
    with ONNX inference.

    Args:
        *args: positional inputs forwarded to ``self(...)`` for the
            dummy pass (typically a network input tensor).
        **kwargs: additional keyword inputs forwarded to ``self(...)``.

    Returns:
        list: the buffers recorded by the padding layers during the
        dummy forward pass, ordered by forward-pass usage.
    """
    # Get the internal buffer state
    # (snapshot in creation order so it can be restored below; note
    # get_buffers returns references, not copies).
    buffers = self.get_buffers()
    # Get the in_buffers via dummy forward pass
    # With in_buffers=None each padding layer falls back to its own
    # internal buffer and records it via append_internal_buffer, so
    # internal_buffers ends up in forward-pass order.
    buffer_io = BufferIO( in_buffers=None )
    self(
        *args,
        inference=True,  # streaming path that consults buffer_io
        buffer_io=buffer_io,
        **kwargs,
    )
    in_buffers = buffer_io.internal_buffers
    # Restore the internal buffer state
    # NOTE(review): assumes the dummy pass may rebind layer buffers;
    # restoring keeps this method side-effect free on the network.
    self.set_buffers( buffers )
    return in_buffers

def set_buffers(self, buffers):
    """
    Set all buffers of the network in the order they were created.

    Args:
        buffers: sequence of tensors, one per ``TemporalPad1d``
            submodule, in the same order as returned by
            ``get_buffers``. The sequence itself is left unmodified.
    """
    # Work on a copy: the previous implementation drained the caller's
    # list in place via pop(0), a surprising side effect on an input
    # argument (and O(n^2) for long lists).
    remaining = list(buffers)
    def _set_buffer(module):
        if isinstance(module, TemporalPad1d):
            module.buffer = remaining.pop(0)
    self.apply(_set_buffer)
    return


def get_padding(kernel_size, dilation=1):
    """
    Return the 'same' padding for a (possibly dilated) 1D convolution.

    Args:
        kernel_size (int): convolution kernel size (>= 1).
        dilation (int): dilation factor (>= 1), defaults to 1.

    Returns:
        int: half of the effective kernel extension,
        ``(kernel_size - 1) * dilation // 2``.
    """
    # Pure integer arithmetic: the original float division plus int()
    # truncation gives the same result for non-negative inputs but can
    # lose precision for very large values.
    return (kernel_size - 1) * dilation // 2
Expand Down

0 comments on commit 3b3deac

Please sign in to comment.