Delta
An open source Machine Learning Framework in Rust Δ.


Desired API Usage

This is a rough sketch of what the API could look like; it is not the final API.

Example 1:

use delta_tensor::Tensor;
use delta_nn::layers::{Dense, Relu};
use delta_nn::models::Sequential;
use delta_optimizers::Adam;

fn main() {
    // Create a neural network
    let mut model = Sequential::new()
        .add(Dense::new(784, 128))   // Input: 784, Output: 128
        .add(Relu::new())            // Activation: ReLU
        .add(Dense::new(128, 10));   // Output: 10 classes

    // Define an optimizer
    let optimizer = Adam::new(0.001); // learning rate: 0.001

    // Compile the model
    model.compile(optimizer);

    // Train the model
    let train_data = delta_datasets::mnist::load_train();
    let test_data = delta_datasets::mnist::load_test();

    model.fit(train_data, 10, 32); // epochs: 10, batch size: 32

    // Evaluate the model
    let accuracy = model.evaluate(test_data);
    println!("Test Accuracy: {:.2}%", accuracy * 100.0);

    // Save the model
    model.save("model_path").unwrap();
}

Example 2:

use delta_tensor::Tensor;
use delta_nn::layers::{Dense, Dropout, Relu, Softmax};
use delta_nn::models::Sequential;
use delta_optimizers::{Adam, LearningRateScheduler};
use delta_datasets::mnist::{self, Dataset};
use delta_callbacks::{EarlyStopping, ModelCheckpoint};

fn main() {
    // Load and preprocess the dataset
    let mut train_data = mnist::load_train();
    let mut test_data = mnist::load_test();

    // Data augmentation (example: normalize and add noise)
    train_data.normalize(0.0, 1.0); // Normalize to [0, 1]
    train_data.add_noise(0.05);     // Add Gaussian noise
    test_data.normalize(0.0, 1.0);

    // Create a neural network
    let mut model = Sequential::new()
        .add(Dense::new(784, 128))     // Input: 784, Output: 128
        .add(Relu::new())              // Activation: ReLU
        .add(Dropout::new(0.2))        // Dropout with 20% probability
        .add(Dense::new(128, 64))      // Intermediate layer
        .add(Relu::new())
        .add(Dense::new(64, 10))       // Output: 10 classes
        .add(Softmax::new());          // Output probabilities

    // Define an advanced optimizer with learning rate scheduling
    let mut optimizer = Adam::new(0.001);
    let scheduler = LearningRateScheduler::new(
        |epoch| if epoch < 5 { 0.001 } else { 0.0001 }
    );
    optimizer.set_scheduler(scheduler);

    // Compile the model
    model.compile(optimizer);

    // Define callbacks
    let mut early_stopping = EarlyStopping::new()
        .patience(3) // Stop training if no improvement for 3 epochs
        .monitor("val_accuracy");
    let checkpoint = ModelCheckpoint::new("best_model_path")
        .save_best_only(true) // Save only the best model
        .monitor("val_accuracy");

    // Train the model using a custom loop
    for epoch in 1..=10 {
        println!("Epoch {}/10", epoch);

        // Training step
        model.train(&train_data, 32);

        // Validation step
        let val_accuracy = model.validate(&test_data);
        println!("Validation Accuracy: {:.2}%", val_accuracy * 100.0);

        // Trigger callbacks
        early_stopping.step(val_accuracy);
        checkpoint.step(&model, val_accuracy);

        // Check for early stopping
        if early_stopping.should_stop() {
            println!("Early stopping triggered.");
            break;
        }
    }

    // Evaluate the final model
    let test_accuracy = model.evaluate(&test_data);
    println!("Final Test Accuracy: {:.2}%", test_accuracy * 100.0);

    // Save the final model
    model.save("final_model_path").unwrap();
}

Creating and printing a model summary:

let model = Sequential::new()
    .add(Dense::new(784, 128))
    .add(Relu::new())
    .add(Dense::new(128, 10));

println!("{:#?}", model.summary());

Output:

Model Summary:
Layer (type)        Output Shape    Param #
-------------------------------------------
Dense               [None, 128]     100,480
Relu                [None, 128]     0
Dense               [None, 10]      1,290
===========================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0

Getting Help

Are you having trouble with Delta? We want to help!

  • If you are upgrading, read the release notes for upgrade instructions and "new and noteworthy" features.

  • Ask a question. We monitor stackoverflow.com for questions tagged with delta-rs.

  • Report bugs with Delta at https://github.com/delta-rs/delta/issues.

Reporting Issues

Delta uses GitHub’s integrated issue tracking system to record bugs and feature requests. If you want to raise an issue, please follow the recommendations below:

  • Before you log a bug, please search the issue tracker to see if someone has already reported the problem.

  • If the issue doesn’t already exist, create a new issue.

  • Please provide as much information as possible with the issue report. We like to know the Delta version, operating system, and Rust version you’re using.

  • If you need to paste code or include a stack trace, use Markdown code fences: place ``` before and after your text.

  • If possible, try to create a test case or project that replicates the problem and attach it to the issue.

Contributing

Before contributing, please read the contribution guide for useful information on how to get started with Delta, as well as what should be included when submitting a contribution to the project.

Contributors

The following contributors have helped start this project, contributed code, actively maintained it (including documentation), or contributed in other meaningful ways. We'd like to take a moment to recognize them.

mjovanc

License

The BSD 3-Clause License.
