# DeepAR Module Documentation

## Overview

This module implements the DeepAR model, a method for probabilistic forecasting with autoregressive recurrent networks. DeepAR is designed to model time series with complex patterns and to provide accurate probabilistic forecasts. This implementation is inspired by the approach described in the DeepAR paper and is adapted for Julia, using Flux for neural network modeling and StatsBase for statistical operations. The module has since been extracted into a separate repository; see https://github.com/josemanuel22/DeepAR.jl.

## Installation

Before using this module, ensure that the required Julia packages Flux, StatsBase, and Random are installed. You can add them to your Julia environment by running:

```julia
using Pkg
Pkg.add(["Flux", "StatsBase", "Random"])
```

## Module Components

### DeepArParams Struct

A structure holding the hyperparameters of the DeepAR model.

```julia
Base.@kwdef mutable struct DeepArParams
    η::Float64 = 1e-2   # Learning rate
    epochs::Int = 10    # Number of training epochs
    n_mean::Int = 100   # Number of samples for the predictive mean
end
```
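
Because the struct is declared with `Base.@kwdef`, it can be constructed with keyword arguments, and any field left unspecified keeps its default:

```julia
# Override only the fields you need; unspecified fields keep their defaults.
hparams = DeepArParams(η = 1e-3, epochs = 50)   # n_mean stays at 100
```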

### train_DeepAR Function

Trains a DeepAR model and records the loss at each step.

```julia
train_DeepAR(model, loaderXtrain, loaderYtrain, hparams) -> Vector{Float64}
```

- **Arguments**:
  - `model`: The DeepAR model to be trained.
  - `loaderXtrain`: DataLoader containing input sequences for training.
  - `loaderYtrain`: DataLoader containing target sequences for training.
  - `hparams`: An instance of `DeepArParams` specifying the training hyperparameters.

- **Returns**: A vector of loss values recorded during training.
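
For context, DeepAR as described in the paper trains by maximizing a Gaussian log-likelihood. The sketch below shows what that loss looks like in plain Julia; it illustrates the idea and is not this module's actual loss code:

```julia
# Illustrative Gaussian negative log-likelihood, the quantity DeepAR-style
# models minimize during training. A sketch for intuition, not this module's
# exact loss code. The network is assumed to predict μ and log σ per step.
function gaussian_nll(μ, logσ, y)
    σ² = exp.(2 .* logσ)                                  # recover the variance
    return sum(0.5 .* log.(2π .* σ²) .+ (y .- μ).^2 ./ (2 .* σ²))
end

gaussian_nll(0.0, 0.0, 1.0)  # ≈ 1.419 for a single standard-normal prediction
```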

### forecasting_DeepAR Function

Generates forecasts from a trained DeepAR model.

```julia
forecasting_DeepAR(model, ts, t₀, τ; n_samples=100) -> Vector{Float32}
```

- **Arguments**:
  - `model`: The trained DeepAR model.
  - `ts`: Time series data to forecast from.
  - `t₀`: Starting time step of the forecast.
  - `τ`: Number of time steps to forecast.
  - `n_samples`: Number of samples drawn at each forecast step (default: 100).

- **Returns**: A vector containing the forecasted value for each time step.
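
Conceptually, such a forecast is produced by ancestral sampling: at each step the model emits the parameters of a predictive distribution, values are sampled, and a summary is fed back as the next input. The following is a minimal sketch under that assumption; `predict_params` is a hypothetical stand-in, not part of this module's API:

```julia
# Illustrative ancestral-sampling loop for a probabilistic forecaster.
# `predict_params` is a hypothetical helper mapping the current input to the
# predicted (μ, σ) of the next value; it is not part of this module's API.
function sample_forecast(predict_params, last_value, τ::Int; n_samples::Int = 100)
    x = last_value
    forecast = Float32[]
    for _ in 1:τ
        μ, σ = predict_params(x)
        samples = μ .+ σ .* randn(Float32, n_samples)  # draw from N(μ, σ²)
        x = sum(samples) / n_samples                   # feed the sample mean back in
        push!(forecast, x)
    end
    return forecast
end

sample_forecast(x -> (x, 0.1f0), 0.0f0, 5)  # toy predictor, five-step forecast
```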

## Example Usage

This section demonstrates how to use the DeepAR model for probabilistic forecasting of a time series.

```julia
# Define AR model parameters and generate training and testing data
ar_hparams = ARParams(...)
loaderXtrain, loaderYtrain, loaderXtest, loaderYtest = generate_batch_train_test_data(...)

# Initialize the DeepAR model
model = Chain(...)

# Define hyperparameters and train the model
deepar_params = DeepArParams(...)
losses = train_DeepAR(model, loaderXtrain, loaderYtrain, deepar_params)

# Perform forecasting
t₀, τ = 100, 20
predictions = forecasting_DeepAR(model, collect(loaderXtrain)[1], t₀, τ; n_samples=100)
```

## References

- ["DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks"](https://arxiv.org/pdf/1704.04110.pdf) by David Salinas, Valentin Flunkert, and Jan Gasthaus.

# Generative Adversarial Networks (GANs) Module Overview

This repository includes a dedicated folder containing implementations of several Generative Adversarial Networks (GANs), showcasing a variety of approaches within the GAN framework. The collection includes:

- **Vanilla GAN**: Based on the foundational GAN concept introduced in ["Generative Adversarial Nets"](https://arxiv.org/pdf/1406.2661.pdf) by Goodfellow et al. This implementation adapts the code from the [FluxGAN repository](https://github.com/AdarshKumar712/FluxGAN) to fit our testing needs.

- **WGAN (Wasserstein GAN)**: Implements the Wasserstein GAN described in ["Wasserstein GAN"](https://arxiv.org/pdf/1701.07875.pdf) by Arjovsky et al., which addresses the training-stability problems of standard GANs. As with the Vanilla GAN, we use a slightly adjusted version of the [FluxGAN repository](https://github.com/AdarshKumar712/FluxGAN) implementation.

- **MMD-GAN (Maximum Mean Discrepancy GAN)**: Inspired by ["MMD GAN: Towards Deeper Understanding of Moment Matching Network"](https://arxiv.org/pdf/1705.08584.pdf) by Li et al. Unlike the previous two models, this implementation has been rewritten in Julia from the authors' original [Python code](https://github.com/OctoberChang/MMD-GAN); the statistic it matches is sketched after this list.
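
For reference, the maximum mean discrepancy that gives MMD-GAN its name can be estimated between two samples with a Gaussian kernel as below. This is a minimal sketch of the statistic for scalar data, for intuition only; it is not the repository's implementation:

```julia
# Minimal sketch: biased estimator of MMD² between two scalar samples with a
# Gaussian (RBF) kernel. For intuition only; not the repository's code.
gaussian_kernel(x, y; σ = 1.0) = exp(-abs2(x - y) / (2σ^2))

function mmd²(xs, ys; σ = 1.0)
    k_xx = sum(gaussian_kernel(x, x′; σ = σ) for x in xs, x′ in xs) / length(xs)^2
    k_yy = sum(gaussian_kernel(y, y′; σ = σ) for y in ys, y′ in ys) / length(ys)^2
    k_xy = sum(gaussian_kernel(x, y; σ = σ) for x in xs, y in ys) / (length(xs) * length(ys))
    return k_xx + k_yy - 2k_xy
end

# A generator matching the data distribution drives this toward zero:
# mmd²(randn(500), randn(500)) ≈ 0, while mmd²(randn(500), randn(500) .+ 2) is clearly positive.
```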

## Objective

The primary goal of incorporating these GAN models into the repository is to evaluate the effectiveness of ISL (Invariant Statistical Learning) methods as regularizers for GAN-based solutions. Specifically, we aim to address the challenges presented by the "Helvetica scenario," exploring how ISL methods can enhance the robustness and generalization of GANs in generating high-quality synthetic data.

## Implementation Details

For each GAN variant above, we adapted the original implementation to ensure compatibility with our testing framework and with the objectives of the ISL method integration. The modifications range from architectural adjustments to changes in the optimization process, and aim to maximize the performance and efficacy of the ISL regularizers in the GAN context.

We encourage interested researchers and practitioners to explore the implementations and to consider the potential of ISL methods for improving GAN architectures. For detailed insights into the modifications and specific implementation choices, please refer to the code and accompanying documentation in the folder of each GAN variant.