
Commit

change readme
josemanuel22 committed Mar 2, 2024
1 parent dab46df commit c1c1767
Showing 6 changed files with 108 additions and 13 deletions.
6 changes: 3 additions & 3 deletions docs/make.jl
```diff
@@ -9,9 +9,9 @@ makedocs(;
     modules=[ISL],
     pages=[
         "Home" => "index.md",
-        "GAN" => "gan.md",
-        "Example" => "example.md",
-        "Benchmark" => "benchmark.md",
+        "GAN" => "Gan.md",
+        "Example" => "Examples.md",
+        "DeepAR" => "DeepAR.md",
     ],
     strict=false,
 )
```
86 changes: 86 additions & 0 deletions docs/src/DeepAR.md
@@ -0,0 +1,86 @@
# DeepAR Module Documentation

## Overview

This module implements DeepAR, a method for probabilistic forecasting with autoregressive recurrent networks. DeepAR is designed to model time series with complex patterns and to produce accurate probabilistic forecasts. The implementation follows the approach described in the DeepAR paper, adapted to Julia using Flux for neural network modeling and StatsBase for statistical operations. The module has since been extracted into a separate repository: [DeepAR.jl](https://github.com/josemanuel22/DeepAR.jl).

## Installation

Before using this module, make sure the required Julia packages are installed: Flux, StatsBase, and Random (Random ships with Julia as a standard library). You can add them to your environment by running:

```julia
using Pkg
Pkg.add(["Flux", "StatsBase", "Random"])
```

## Module Components

### DeepArParams Struct

A structure to hold hyperparameters for the DeepAR model.

```julia
Base.@kwdef mutable struct DeepArParams
    η::Float64 = 1e-2  # Learning rate
    epochs::Int = 10   # Number of training epochs
    n_mean::Int = 100  # Number of samples for predictive mean
end
```
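
Because `DeepArParams` is declared with `Base.@kwdef`, it can be constructed with keyword arguments, overriding any subset of the defaults. For example (the values here are purely illustrative):

```julia
hparams = DeepArParams(η=1e-3, epochs=50)  # n_mean keeps its default of 100
```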

### train_DeepAR Function

Function to train a DeepAR model.

```julia
train_DeepAR(model, loaderXtrain, loaderYtrain, hparams) -> Vector{Float64}
```

- **Arguments**:
- `model`: The DeepAR model to be trained.
- `loaderXtrain`: DataLoader containing input sequences for training.
- `loaderYtrain`: DataLoader containing target sequences for training.
- `hparams`: An instance of `DeepArParams` specifying training hyperparameters.

- **Returns**: A vector of loss values recorded during training.
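
As a rough sketch of what such a training loop can look like, assuming a recurrent Flux model that maps a one-element input vector to a `(μ, log σ)` pair and is trained by minimizing the Gaussian negative log-likelihood (as in the DeepAR paper), one might write the following. This is illustrative only, not the module's actual implementation:

```julia
using Flux

# Illustrative sketch only; not the module's actual implementation.
# Assumes `model` is a recurrent Flux network mapping a one-element input
# vector to a two-element output interpreted as (μ, log σ).
function train_deepar_sketch(model, loaderXtrain, loaderYtrain, hparams)
    losses = Float64[]
    opt_state = Flux.setup(Adam(hparams.η), model)
    for _ in 1:hparams.epochs
        for (x, y) in zip(loaderXtrain, loaderYtrain)
            loss, grads = Flux.withgradient(model) do m
                Flux.reset!(m)  # clear the recurrent state between sequences
                # Gaussian negative log-likelihood, accumulated over the sequence
                sum(zip(x, y)) do (xₜ, yₜ)
                    μ, logσ = m([xₜ])
                    σ = exp(logσ)
                    0.5f0 * ((yₜ - μ) / σ)^2 + logσ
                end
            end
            Flux.update!(opt_state, model, grads[1])
            push!(losses, loss)
        end
    end
    return losses
end
```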

### forecasting_DeepAR Function

Function to generate forecasts using a trained DeepAR model.

```julia
forecasting_DeepAR(model, ts, t₀, τ; n_samples=100) -> Vector{Float32}
```

- **Arguments**:
- `model`: The trained DeepAR model.
- `ts`: Time series data for forecasting.
- `t₀`: Starting time step for forecasting.
- `τ`: Number of time steps to forecast.
- `n_samples`: Number of samples to draw for each forecast step (default: 100).

- **Returns**: A vector containing the forecasted values for each time step.
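
Forecasting of this kind is typically autoregressive: the recurrent state is first conditioned on the observed history, and each prediction is then fed back as the next input. Under the same Gaussian-output assumption as above, a rough sketch (with illustrative names, not the module's code) is:

```julia
using Flux, Statistics

# Illustrative sketch only; not the module's actual implementation.
function forecasting_deepar_sketch(model, ts, t₀, τ; n_samples=100)
    Flux.reset!(model)
    # Condition the recurrent state on the observed history
    for xₜ in ts[1:t₀]
        model([xₜ])
    end
    forecasts = Float32[]
    xₜ = ts[t₀]
    for _ in 1:τ
        μ, logσ = model([xₜ])
        σ = exp(logσ)
        # Approximate the predictive mean by averaging Gaussian samples
        x̂ = mean(μ .+ σ .* randn(Float32, n_samples))
        push!(forecasts, x̂)
        xₜ = x̂  # feed the prediction back as the next input
    end
    return forecasts
end
```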

## Example Usage

This section demonstrates how to use the DeepAR model for probabilistic forecasting of time series data.

```julia
# Define AR model parameters and generate training and testing data
ar_hparams = ARParams(...)
loaderXtrain, loaderYtrain, loaderXtest, loaderYtest = generate_batch_train_test_data(...)

# Initialize the DeepAR model
model = Chain(...)

# Define hyperparameters and train the model
deepar_params = DeepArParams(...)
losses = train_DeepAR(model, loaderXtrain, loaderYtrain, deepar_params)

# Perform forecasting
t₀, τ = 100, 20
predictions = forecasting_DeepAR(model, collect(loaderXtrain)[1], t₀, τ; n_samples=100)
```
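
For a univariate series, the `Chain(...)` placeholder above could be a small recurrent network whose final layer emits the two Gaussian parameters per step. A toy configuration, given here as an assumption for illustration rather than a recommended architecture:

```julia
using Flux

# LSTM encoder followed by a Dense head producing (μ, log σ) per input step
model = Chain(LSTM(1 => 10), Dense(10 => 2))
```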

## References

- ["DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks"](https://arxiv.org/pdf/1704.04110.pdf) by David Salinas, Valentin Flunkert, and Jan Gasthaus.

File renamed without changes.
19 changes: 19 additions & 0 deletions docs/src/Gans.md
@@ -0,0 +1,19 @@
# Generative Adversarial Networks (GANs) Module Overview

This repository includes a dedicated folder that contains implementations of different Generative Adversarial Networks (GANs), showcasing a variety of approaches within the GAN framework. Our collection includes:

- **Vanilla GAN**: Based on the foundational GAN concept introduced in ["Generative Adversarial Nets"](https://arxiv.org/pdf/1406.2661.pdf) by Goodfellow et al. This implementation adapts and modifies the code from [FluxGAN repository](https://github.com/AdarshKumar712/FluxGAN) to fit our testing needs.

- **WGAN (Wasserstein GAN)**: Implements the Wasserstein GAN as described in ["Wasserstein GAN"](https://arxiv.org/pdf/1701.07875.pdf) by Arjovsky et al., providing an advanced solution to the issue of training stability in GANs. Similar to Vanilla GAN, we have utilized and slightly adjusted the implementation from the [FluxGAN repository](https://github.com/AdarshKumar712/FluxGAN).

- **MMD-GAN (Maximum Mean Discrepancy GAN)**: Our implementation of MMD-GAN is inspired by the paper ["MMD GAN: Towards Deeper Understanding of Moment Matching Network"](https://arxiv.org/pdf/1705.08584.pdf) by Li et al. Unlike the previous models, the MMD-GAN implementation has been rewritten in Julia, transitioning from the original [Python code](https://github.com/OctoberChang/MMD-GAN) provided by the authors.

## Objective

The primary goal of incorporating these GAN models into our repository is to evaluate the effectiveness of ISL (Invariant Statistical Learning) methods as regularizers for GAN-based solutions. Specifically, we aim to address the challenges presented in the "Helvetica scenario" (mode collapse, where the generator concentrates its mass on a few outputs), exploring how ISL methods can enhance the robustness and generalization of GANs in generating high-quality synthetic data.

## Implementation Details

For each GAN variant mentioned above, we have made certain adaptations to the original implementations to ensure compatibility with our testing framework and the objectives of the ISL method integration. These modifications range from architectural adjustments to changes in the optimization process, and aim to maximize the performance and efficacy of the ISL regularizers within the GAN context.

We encourage interested researchers and practitioners to explore the implementations and consider the potential of ISL methods in improving GAN architectures. For more detailed insights into the modifications and specific implementation choices, please refer to the code and accompanying documentation within the respective folders for each GAN variant.
3 changes: 0 additions & 3 deletions docs/src/benchmark.md

This file was deleted.

7 changes: 0 additions & 7 deletions docs/src/gan.md

This file was deleted.
