Skip to content

Commit

Permalink
examples.md
Browse files Browse the repository at this point in the history
  • Loading branch information
josemanuel22 committed Mar 1, 2024
1 parent 45bed92 commit e38faf7
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 3 deletions.
64 changes: 62 additions & 2 deletions docs/src/example.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
To make simple use, once the package is installed, just run the examples. For instance, execute,
To make simple use, once the package is installed, just run the examples.

# Learning 1-D distributions

```julia
# This example is from examples/Learning1d_distribution/benchmark_unimodal.jl
Expand Down Expand Up @@ -46,4 +48,62 @@ include("../utils.jl")
end
```

![Example Image](./imgs/readme_images_1.png)
![Example Image](./imgs/readme_images_1.png)


# Time Series

```
@test_experiments "testing AutoRegressive Model 1" begin
# --- Model Parameters and Data Generation ---
# Define AR model parameters
ar_hparams = ARParams(;
ϕ=[0.5f0, 0.3f0, 0.2f0], # Autoregressive coefficients
x₁=rand(Normal(0.0f0, 1.0f0)), # Initial value from a Normal distribution
proclen=2000, # Length of the process
noise=Normal(0.0f0, 0.2f0), # Noise in the AR process
)
# Define the recurrent and generative models
recurrent_model = Chain(RNN(1 => 10, relu), RNN(10 => 10, relu))
generative_model = Chain(Dense(11, 16, relu), Dense(16, 1, identity))
# Generate training and testing data
n_series = 200 # Number of series to generate
loaderXtrain, loaderYtrain, loaderXtest, loaderYtest = generate_batch_train_test_data(
n_series, ar_hparams
)
# --- Training Configuration ---
# Define hyperparameters for time series prediction
ts_hparams = HyperParamsTS(;
seed=1234,
η=1e-3, # Learning rate
epochs=n_series,
window_size=1000, # Size of the window for prediction
K=10, # Hyperparameter K (if it has a specific use, add a comment)
)
# Train model and calculate loss
loss = ts_invariant_statistical_loss_one_step_prediction(
recurrent_model, generative_model, loaderXtrain, loaderYtrain, ts_hparams
)
# --- Visualization ---
# Plotting the time series prediction
plot_univariate_ts_prediction(
recurrent_model,
generative_model,
collect(loaderXtrain)[2], # Extract the first batch for plotting
collect(loaderXtest)[2], # Extract the first batch for plotting
ts_hparams;
n_average=1000, # Number of predictions to average
)
end
```

![Example Image](./imgs/readme_images_2.png)

Binary file added docs/src/imgs/readme_images_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/time_series_predictions/benchmark_ts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ end

# Define a function to generate synthetic data.
"""
generate_synthetic(range)
generate_synthetic(range)
Generate synthetic time series data according to a composite time series model.
Expand Down

0 comments on commit e38faf7

Please sign in to comment.