From db76e4e4da8d1ffc9ff99e0f170be8bd9d7868bf Mon Sep 17 00:00:00 2001 From: paul-krug Date: Tue, 6 Feb 2024 18:12:41 +0100 Subject: [PATCH 1/2] small fixes and added unit tests --- pytorch_tcn/tcn.py | 15 +++--- tests/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/test_tcn.py | 118 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 127 insertions(+), 6 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_tcn.py diff --git a/pytorch_tcn/tcn.py b/pytorch_tcn/tcn.py index d0d24fc..acd98f5 100644 --- a/pytorch_tcn/tcn.py +++ b/pytorch_tcn/tcn.py @@ -187,6 +187,9 @@ def __init__( self.norm2 = None self.conv1 = weight_norm(self.conv1) self.conv2 = weight_norm(self.conv2) + elif use_norm is None: + self.norm1 = None + self.norm2 = None self.activation1 = activation_fn[ self.activation_name ]() self.activation2 = activation_fn[ self.activation_name ]() @@ -269,10 +272,10 @@ def __init__( if dilations is not None and len(dilations) != len(num_channels): raise ValueError("Length of dilations must match length of num_channels") - allowed_norm_values = ['batch_norm', 'layer_norm', 'weight_norm', None] - if use_norm not in allowed_norm_values: + self.allowed_norm_values = ['batch_norm', 'layer_norm', 'weight_norm', None] + if use_norm not in self.allowed_norm_values: raise ValueError( - f"Argument 'use_norm' must be one of: {allowed_norm_values}" + f"Argument 'use_norm' must be one of: {self.allowed_norm_values}" ) if activation not in activation_fn.keys(): @@ -285,10 +288,10 @@ def __init__( f"Argument 'kernel_initializer' must be one of: {kernel_init_fn.keys()}" ) - allowed_input_shapes = ['NCL', 'NLC'] - if input_shape not in allowed_input_shapes: + self.allowed_input_shapes = ['NCL', 'NLC'] + if input_shape not in self.allowed_input_shapes: raise ValueError( - f"Argument 'input_shape' must be one of: {allowed_input_shapes}" + f"Argument 'input_shape' must be one of: {self.allowed_input_shapes}" ) if dilations is None: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_tcn.py b/tests/unit/test_tcn.py new file mode 100644 index 0000000..ef85d9c --- /dev/null +++ b/tests/unit/test_tcn.py @@ -0,0 +1,118 @@ +import unittest +import torch +import pytorch_tcn +from pytorch_tcn import TCN + + +class TestTCN(unittest.TestCase): + + def __init__(self, methodName: str = "runTest") -> None: + super().__init__(methodName) + + self.num_inputs = 20 + self.num_channels = [ + 32, 64, 64, 128, + 32, 64, 64, 128, + ] + + return + + def test_tcn(self, **kwargs): + tcn = TCN( + num_inputs = self.num_inputs, + num_channels = self.num_channels, + **kwargs, + ) + + time_steps = 196 + x = torch.randn( 10, self.num_inputs, time_steps ) + y = tcn(x) + + self.assertEqual( y.shape, (10, self.num_channels[-1], time_steps) ) + return + + def test_kernel_size(self): + self.test_tcn( kernel_size = 7 ) + return + + def test_dilations(self): + # dilations list len != len(num_channels) + with self.assertRaises(ValueError): + self.test_tcn( dilations = [1, 2, 3, 4] ) + + # dilations list len == len(num_channels) + self.test_tcn( dilations = [1, 2, 3, 4, 1, 2, 3, 4] ) + return + + def test_dropout(self): + self.test_tcn( dropout = 0.5 ) + return + + def test_causal(self): + self.test_tcn( causal = True ) + return + + def test_non_causal(self): + self.test_tcn( causal = False ) + return + + def test_norms(self): + available_norms = TCN(10,[10]).allowed_norm_values + for norm in available_norms: + print( 'Testing norm:', norm ) + self.test_tcn( use_norm = norm ) + + with self.assertRaises(ValueError): + self.test_tcn( use_norm = 'invalid' ) + return + + def test_activations(self): + available_activations = pytorch_tcn.tcn.activation_fn.keys() + for activation in available_activations: + self.test_tcn( activation = activation ) + + with self.assertRaises(ValueError): + self.test_tcn( activation = 'invalid' ) + return + + def test_kernel_initializers(self): + available_initializers = pytorch_tcn.tcn.kernel_init_fn.keys() + for initializer in available_initializers: + self.test_tcn( kernel_initializer = initializer ) + + with self.assertRaises(ValueError): + self.test_tcn( kernel_initializer = 'invalid' ) + return + + def test_skip_connections(self): + self.test_tcn( use_skip_connections = True ) + self.test_tcn( use_skip_connections = False ) + return + + def test_input_shape(self): + self.test_tcn( input_shape = 'NCL' ) + + # Test NLC + tcn = TCN( + num_inputs = self.num_inputs, + num_channels = self.num_channels, + input_shape = 'NLC' + ) + + time_steps = 196 + x = torch.randn( 10, time_steps, self.num_inputs, ) + y = tcn(x) + + self.assertEqual( y.shape, (10, time_steps, self.num_channels[-1]) ) + + with self.assertRaises(ValueError): + self.test_tcn( input_shape = 'invalid' ) + return + + + + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From bdfba2da07d3b2ba9b357aa2da28fdf15046ebd8 Mon Sep 17 00:00:00 2001 From: paul-krug Date: Tue, 6 Feb 2024 18:28:05 +0100 Subject: [PATCH 2/2] Add CI --- .github/workflows/ci.yaml | 40 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 .github/workflows/ci.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..a3b03c0 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,40 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + run-tests: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: + - '3.8' + - '3.9' + - '3.10' + - '3.11' + - '3.12' + + name: Test + runs-on: ${{ matrix.os }} + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: python -m pip install . + + - name: Run tests + run: python -m unittest discover \ No newline at end of file