Skip to content

Commit 759e37f

Browse files
committed
Added time series generation example
1 parent 75afe95 commit 759e37f

7 files changed

+60
-35
lines changed

Diff for: .gitignore

+2-2
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,5 @@ dmypy.json
128128
# CI/ CD
129129
exported/
130130

131-
# Trained models
132-
examples/model.pt
131+
# Results from the example scripts
132+
examples/*.pt

Diff for: README.md

+19-13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ It's main characteristic is the lateral in-/exhibition of neurons though their a
1919
Due to its simplicity and expressiveness, Amari’s work was highly influential and led to several follow-up papers such
2020
as [2-6] to only name a few.
2121

22+
## Support
23+
24+
If you use code or ideas from this repository for your projects or research, **please cite it**.
25+
26+
```
27+
@misc{Muratore_neuralfields,
28+
author = {Fabio Muratore},
29+
title = {neuralfields - A type of potential-based recurrent neural networks implemented with PyTorch},
30+
year = {2023},
31+
publisher = {GitHub},
32+
journal = {GitHub repository},
33+
howpublished = {\url{https://github.com/famura/neuralfields}}
34+
}
35+
```
36+
2237
## Features
2338

2439
* There are two variants of the neural fields implemented in this repository: one called `NeuralField` that matches
@@ -36,20 +51,11 @@ as [2-6] to only name a few.
3651
sim-to-real transfer. However, the goal of this repository is to make the implementation **as general as possible**,
3752
such that it could for example be used as generative model.
3853

39-
## Support
40-
41-
If you use code or ideas from this repository for your projects or research, **please cite it**.
54+
### Time series learning example
55+
![](examples/time_series_learning.png) ![](exported/examples/time_series_learning.png)
4256

43-
```
44-
@misc{Muratore_neuralfields,
45-
author = {Fabio Muratore},
46-
title = {neuralfields - A type of potential-based recurrent neural networks implemented with PyTorch},
47-
year = {2023},
48-
publisher = {GitHub},
49-
journal = {GitHub repository},
50-
howpublished = {\url{https://github.com/famura/neuralfields}}
51-
}
52-
```
57+
### Time series generation example
58+
![](examples/time_series_generation.png) ![](exported/examples/time_series_generation.png)
5359

5460
## Getting Started
5561

Diff for: examples/time_series_generation.png

199 KB
Loading

Diff for: examples/time_series_generation.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import matplotlib.pyplot as plt
24
import seaborn
35
import torch
@@ -11,9 +13,11 @@
1113
# Configure.
1214
torch.manual_seed(0)
1315
num_samples = 5
14-
len_time_series = 50
15-
random_inputs = True # if False, the inputs will be zero for every sample (vary via the hidden state)
16-
default_hidden = False # if True, the hidden state will be initialized with zeroes by the net (vary via the inputs)
16+
use_test_inputs = True # if False, the inputs will be zero for every sample (vary via the hidden state)
17+
default_hidden = True # if True, the hidden state will be initialized with zeroes by the net (vary via the inputs)
18+
if not use_test_inputs and default_hidden:
19+
warnings.warn("All generated sequences will be the same. Please change the scripts configuration.")
20+
len_time_series = 800 if use_test_inputs else 10
1721

1822
# Load the model previously trained with the time_series_learning.py example script.
1923
try:
@@ -27,18 +31,22 @@
2731

2832
# Use the model to generate several time series. This could either be done by providing inputs that are different
2933
# along the first dimension of the tensor, or by varying the initial hidden state.
30-
inputs = torch.zeros(num_samples, len_time_series, model.input_size)
31-
if random_inputs:
32-
inputs[:, len_time_series // 2 :, :] = (
33-
torch.linspace(0, 8, len_time_series // 2) # start and end of the linspace are arbitrary
34-
.view(1, -1, 1)
35-
.repeat(num_samples, 1, model.input_size)
36-
)
37-
inputs[:, len_time_series // 2 :, :] += torch.randn(num_samples, len_time_series // 2, model.input_size)
34+
if use_test_inputs:
35+
try:
36+
inputs = torch.load(EXAMPLES_DIR / "data_tst.pt").unsqueeze(0).repeat(num_samples, 1, 1)
37+
except FileNotFoundError:
38+
raise FileNotFoundError(
39+
"There was no file called 'data_tst.pt' found in neuralfields' example directory. Most likely, you "
40+
"need to run the 'time_series_learning.py' script first."
41+
)
42+
inputs = inputs[:, :len_time_series, :]
43+
inputs[1:] += torch.randn(num_samples - 1, len_time_series, model.input_size) / 50
44+
else:
45+
inputs = torch.zeros(num_samples, len_time_series, model.input_size)
3846
if default_hidden:
3947
hidden = None
4048
else:
41-
hidden = torch.randn(num_samples, model.hidden_size) * 5
49+
hidden = torch.randn(num_samples, model.hidden_size) * 10
4250
with torch.no_grad():
4351
generated, _ = model(inputs, hidden)
4452

@@ -47,4 +55,6 @@
4755
for idx_ts, gen_ts in enumerate(generated):
4856
plt.plot(gen_ts.numpy(), label=f"sample {idx_ts}")
4957
axs.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", borderaxespad=0, ncol=inputs.size(0))
58+
axs.set_xlabel("months")
59+
axs.set_ylabel("spot count")
5060
plt.show()

Diff for: examples/time_series_learning.png

175 KB
Loading

Diff for: examples/time_series_learning.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ def create_figure(dataset_name: str) -> plt.Figure:
3939
return fig
4040

4141

42+
def plot_results(fig: plt.Figure) -> None:
43+
axs = fig.get_axes()
44+
axs[0].plot(predictions_trn, label="predictions")
45+
axs[1].plot(predictions_tst, label="predictions")
46+
axs[0].legend(loc="upper right", ncol=2)
47+
axs[1].legend(loc="upper right", ncol=2)
48+
49+
4250
def simple_training_loop(
4351
model: torch.nn.Module,
4452
packed_inputs: torch.Tensor,
@@ -79,8 +87,8 @@ def simple_training_loop(
7987
# Configure.
8088
torch.manual_seed(0)
8189
use_simplification = False # switch between models
82-
normalize_data = False # scales the data to be in [-1, 1]
83-
dataset_name = "mackey_glass" # monthly_sunspots or mackey_glass
90+
normalize_data = True # scales the data to be in [-1, 1]
91+
dataset_name = "monthly_sunspots" # monthly_sunspots or mackey_glass
8492

8593
# Get the data.
8694
data, data_trn, data_tst = load_and_split_data(dataset_name, normalize_data)
@@ -135,13 +143,11 @@ def simple_training_loop(
135143
predictions_tst, _ = model(data_tst[:-1].unsqueeze(0), hidden=None)
136144
predictions_tst = predictions_tst.squeeze(0).detach().numpy()
137145

138-
# Safe the model.
146+
# Safe the model and the associated data.
139147
torch.save(model, EXAMPLES_DIR / "model.pt")
148+
torch.save(data_trn, EXAMPLES_DIR / "data_trn.pt")
149+
torch.save(data_tst, EXAMPLES_DIR / "data_tst.pt")
140150

141151
# Plot the results.
142-
axs = fig.get_axes()
143-
axs[0].plot(predictions_trn, label="predictions")
144-
axs[1].plot(predictions_tst, label="predictions")
145-
axs[0].legend(loc="upper right", ncol=2)
146-
axs[1].legend(loc="upper right", ncol=2)
152+
plot_results(fig)
147153
plt.show()

Diff for: pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ sequence = [
270270
{ cmd = "genbadge tests --input-file pytest.xml --output-file docs/exported/tests/badge.svg" },
271271
{ cmd = "pip-licenses --format markdown --with-authors --with-urls --with-description --output-file docs/exported/third_party_licenses.md" },
272272
{ shell = "poe --help > docs/exported/poe_options.txt" },
273+
{ shell = "mkdir -p docs/exported/examples" },
274+
{ shell = "cp examples/time_series_learning.png docs/exported/examples/time_series_learning.png" },
275+
{ shell = "cp examples/time_series_generation.png docs/exported/examples/time_series_generation.png" },
273276
]
274277
help = "Get the git change log to. Next, create the badges. Finally, fetch all thrid party licenses and add them to the documentation."
275278

0 commit comments

Comments
 (0)