diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..6177e4d --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,15 @@ +{ + "name": "C Development", + "image": "mcr.microsoft.com/devcontainers/cpp:1-debian-11", + "features": { + "ghcr.io/devcontainers/features/cpp:1": {} + }, + "customizations": { + "vscode": { + "extensions": ["ms-vscode.cpptools", "ms-vscode.cmake-tools"] + } + }, + "forwardPorts": [8000], + "postCreateCommand": "sudo apt-get update && sudo apt-get install -y libopenblas-dev liblapack-dev", + "remoteUser": "vscode" +} diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..0c5637f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +* text=auto whitespace=trailing-space + +*.png binary +*.jpe?g binary diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..90e0456 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2db1416 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +# CMake & Testing +/build* +/tmp* + +# CLion +/cmake-build* + +# MacOS +.DS_Store + +# Others +.venv +__pycache__ +.pytest_cache +.cache +.CMake/a.out +compile_commands.json +scripts diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..f5a5d23 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog + +## 0.0.1 - 2024-08-29 - @0xnu + +* Initial release diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..c95e64a --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +@0xnu diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..2150bda --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +f@finbarrs.eu. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..253002d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,55 @@ +### Contributing + +1. Fork the repository +2. Create your feature branch: `git checkout -b my-new-feature` +3. Make your changes and ensure they follow the project's coding standards +4. Compile and test your changes locally +5. Commit your changes: `git commit -s -m 'Add some feature'` +6. Push to the branch: `git push origin my-new-feature` +7. Create a pull request + +### Building and Testing + +To build the project: +``` +make all +``` + +To run tests: +``` +make run-tests +``` + +For a full list of available make targets, run: +``` +make help +``` + +### Code Style + +Please follow the coding style guidelines outlined in the project. Typically for C projects, this might include: + +- Use 4 spaces for indentation (not tabs) +- Keep lines to a maximum of 80 characters +- Use descriptive variable and function names +- Comment your code where necessary + +### Submitting Changes + +- Ensure your code compiles without warnings +- Run all tests and make sure they pass +- Update documentation if you're changing functionality +- Include comments in your code where necessary + +**After your pull request is merged**, you can safely delete your branch. + +### Reporting Issues + +If you find a bug or have a suggestion for improvement, please open an issue on the project's issue tracker. Provide as much detail as possible, including: + +- Steps to reproduce the issue +- Expected behavior +- Actual behavior +- Your environment (compiler version, OS, etc.) + +Thank you for contributing to this project! diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f53522d --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2024, Finbarrs Oketunji. All Rights Reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..753dd9a --- /dev/null +++ b/Makefile @@ -0,0 +1,84 @@ +# qrme is a quantum-resistant encrypted machine learning system designed +# to protect sensitive data and models against potential threats from +# quantum computing. It utilises advanced cryptographic techniques to ensure +# the confidentiality and integrity of the machine learning model and its +# inputs/outputs, even in the face of future quantum attacks. +# +# Copyright (c) 2024 Finbarrs Oketunji +# Written by Finbarrs Oketunji +# +# This file is part of qrme. +# +# qrme is an open-source software: you are free to redistribute +# and/or modify it under the terms specified in version 3 of the GNU +# General Public License, as published by the Free Software Foundation. +# +# qrme is is made available with the hope that it will be beneficial, +# but it comes with NO WARRANTY whatsoever. This includes, but is not limited +# to, any implied warranties of MERCHANTABILITY or FITNESS FOR A PARTICULAR +# PURPOSE. For more detailed information, please refer to the +# GNU General Public License. +# +# You should have received a copy of the GNU General Public License +# along with qrme. If not, visit . + +# Detect the operating system +UNAME_S := $(shell uname -s) + +# Common variables +CC = gcc +CFLAGS = -O3 -I. +LDFLAGS = -loqs -lcrypto -lm + +# Source files +SRC = src/encryption.c src/model.c src/utils.c +OBJ = $(SRC:.c=.o) + +# Test files +TEST_SRC = tests/test_all.c +TEST_OBJ = $(TEST_SRC:.c=.o) + +# OS-specific configurations +ifeq ($(UNAME_S),Darwin) + # macOS configuration + LIBOQS_INCLUDE = -I/usr/local/include + LIBOQS_LIB = -L/usr/local/lib + # Install liboqs if not already installed + ifeq ($(shell brew list --formula | grep -q liboqs; echo $$?),1) + $(shell brew install liboqs) + endif +else + # Linux configuration + LIBOQS_INCLUDE = -I/usr/include + LIBOQS_LIB = -L/usr/lib + # Linux package installation + PACKAGES = gcc libssl-dev liboqs-dev + $(shell sudo apt-get update && sudo apt-get install -y $(PACKAGES)) +endif + +all: create_sample_model test_all ## Build all targets + +create_sample_model: create_sample_model.c $(OBJ) ## Build the sample model creation tool + $(CC) $(CFLAGS) $(LIBOQS_INCLUDE) -o $@ $^ $(LIBOQS_LIB) $(LDFLAGS) + +test_all: $(TEST_SRC) $(OBJ) ## Build the test runner + $(CC) $(CFLAGS) $(LIBOQS_INCLUDE) -o $@ $^ $(LIBOQS_LIB) $(LDFLAGS) + +%.o: %.c ## Compile object files + $(CC) $(CFLAGS) $(LIBOQS_INCLUDE) -c $< -o $@ + +run-sample: create_sample_model ## Run the sample model creation + ./create_sample_model + +run-tests: test_all ## Run all tests + ./test_all + +clean: ## Clean up build artifacts + rm -f $(OBJ) $(TEST_OBJ) create_sample_model test_all test_model.bin test_secret.key + +help: ## Display help message + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: all run-sample run-tests clean help + +.DEFAULT_GOAL := help diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..1d43a4a --- /dev/null +++ b/NOTICE @@ -0,0 +1,6 @@ +@name: qrme +@author: Finbarrs Oketunji +@contact: f@finbarrs.eu +@time: 24/08/2024 - 21:30 + +This product includes software developed by [Finbarrs Oketunji](https://finbarrs.eu). diff --git a/README.md b/README.md new file mode 100644 index 0000000..6a28381 --- /dev/null +++ b/README.md @@ -0,0 +1,44 @@ +## qrme + +`qrme` is a quantum-resistant encrypted machine learning system designed to protect sensitive data and models against potential threats from [quantum computing](https://en.wikipedia.org/wiki/Quantum_computing). It utilises advanced cryptographic techniques to ensure the confidentiality and integrity of the machine learning model and its inputs/outputs, even in the face of future quantum attacks. + +### Usage + +You can start using `qrme` by executing: + +```sh +make help +``` + +There's a minimal integration example [here](./create_sample_model.c). Don't forget to implement secure methods for key distribution and storage and ensure the integrity of the model file in a production environment. + +### References + ++ [Quantum-Resistant Cryptography](https://arxiv.org/abs/2112.00399) ++ [Applications of Post-quantum Cryptography](https://arxiv.org/abs/2406.13258) ++ [Preparing for Quantum-Safe Cryptography](https://www.ncsc.gov.uk/whitepaper/preparing-for-quantum-safe-cryptography) ++ [Next Steps in Preparing for Post-Quantum Cryptography](https://www.ncsc.gov.uk/whitepaper/next-steps-preparing-for-post-quantum-cryptography) ++ [The Impact of Quantum Computing on Present Cryptography](https://arxiv.org/abs/1804.00200) ++ [Post-Quantum Cryptography for Internet of Things: A Survey on Performance and Optimization](https://arxiv.org/abs/2401.17538) + +### License + +This project is licensed under the [BSD 3-Clause](LICENSE) License. + +### Citation + +```tex +@misc{qrme, + author = {Oketunji, A.F.}, + title = {Quantum-Resistant Model Encryption (QRME)}, + year = 2024, + version = {0.0.1}, + publisher = {Zenodo}, + doi = {10.5281/zenodo.13449375}, + url = {https://doi.org/10.5281/zenodo.13449375} +} +``` + +### Copyright + +(c) 2024 [Finbarrs Oketunji](https://finbarrs.eu). All Rights Reserved. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..9f09c7e --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,5 @@ +### Security Policy + +#### Reporting a Vulnerability + +Please report security issues to [0xnu](mailto:f@finbarrs.eu) diff --git a/create_sample_model.c b/create_sample_model.c new file mode 100644 index 0000000..1007939 --- /dev/null +++ b/create_sample_model.c @@ -0,0 +1,137 @@ +#include +#include +#include +#include "include/model.h" +#include "include/encryption.h" +#include "include/utils.h" + +#define TEST_MODEL_FILE "test_model.bin" +#define TEST_SECRET_KEY_FILE "test_secret.key" +#define INPUT_SIZE 3 +#define HIDDEN_SIZE 2 +#define OUTPUT_SIZE 1 + +int main() { + printf("Creating sample encrypted model...\n"); + + // Initialize modules + init_encryption(); + init_random(); + + // Create a new model + Model* model = create_model(); + printf("Model created at %p\n", (void*)model); + if (model == NULL) { + fprintf(stderr, "Failed to create model: %s\n", get_model_error()); + return 1; + } + + // Add first layer (INPUT_SIZE -> HIDDEN_SIZE) + float* weights1 = generate_random_float_array(INPUT_SIZE * HIDDEN_SIZE, -1.0f, 1.0f); + printf("Weights1 allocated at %p\n", (void*)weights1); + if (weights1 == NULL) { + fprintf(stderr, "Failed to generate weights for first layer: %s\n", get_utils_error()); + free_model(model); + return 1; + } + + if (add_layer(model, weights1, HIDDEN_SIZE, INPUT_SIZE) != 0) { + fprintf(stderr, "Failed to add first layer: %s\n", get_model_error()); + free(weights1); + free_model(model); + return 1; + } + // Do not free weights1 here, as add_layer has taken ownership of it + + // Add second layer (HIDDEN_SIZE -> OUTPUT_SIZE) + float* weights2 = generate_random_float_array(HIDDEN_SIZE * OUTPUT_SIZE, -1.0f, 1.0f); + printf("Weights2 allocated at %p\n", (void*)weights2); + if (weights2 == NULL) { + fprintf(stderr, "Failed to generate weights for second layer: %s\n", get_utils_error()); + free_model(model); + return 1; + } + + if (add_layer(model, weights2, OUTPUT_SIZE, HIDDEN_SIZE) != 0) { + fprintf(stderr, "Failed to add second layer: %s\n", get_model_error()); + free(weights2); + free_model(model); + return 1; + } + // Do not free weights2 here, as add_layer has taken ownership of it + + // Generate a key pair for encryption + uint8_t *public_key, *secret_key; + size_t public_key_len, secret_key_len; + if (generate_keypair(&public_key, &public_key_len, &secret_key, &secret_key_len) != 0) { + fprintf(stderr, "Failed to generate key pair: %s\n", get_error()); + free_model(model); + return 1; + } + printf("Public key allocated at %p\n", (void*)public_key); + printf("Secret key allocated at %p\n", (void*)secret_key); + + // Save the model + if (save_model(model, TEST_MODEL_FILE, public_key, public_key_len) != 0) { + fprintf(stderr, "Failed to save model: %s\n", get_model_error()); + free_model(model); + cleanup(public_key); + cleanup(secret_key); + return 1; + } + + printf("Sample encrypted model created and saved as %s\n", TEST_MODEL_FILE); + printf("Model structure:\n"); + printf(" Input size: %d\n", INPUT_SIZE); + printf(" Hidden layer size: %d\n", HIDDEN_SIZE); + printf(" Output size: %d\n", OUTPUT_SIZE); + + // Save the secret key to a separate file for testing purposes + FILE* key_file = fopen(TEST_SECRET_KEY_FILE, "wb"); + if (key_file == NULL) { + fprintf(stderr, "Failed to create secret key file\n"); + } else { + fwrite(&secret_key_len, sizeof(size_t), 1, key_file); + fwrite(secret_key, 1, secret_key_len, key_file); + fclose(key_file); + printf("Secret key saved as %s\n", TEST_SECRET_KEY_FILE); + } + + // Test loading and inference + Model* loaded_model = load_model(TEST_MODEL_FILE, secret_key, secret_key_len); + printf("Loaded model at %p\n", (void*)loaded_model); + if (loaded_model == NULL) { + fprintf(stderr, "Failed to load model: %s\n", get_model_error()); + } else { + printf("Successfully loaded the model.\n"); + + // Perform a test inference + float test_input[INPUT_SIZE] = {0.5f, -0.3f, 0.8f}; + float test_output[OUTPUT_SIZE]; + + if (inference(loaded_model, test_input, INPUT_SIZE, test_output, OUTPUT_SIZE) == 0) { + printf("Test inference result: %f\n", test_output[0]); + } else { + fprintf(stderr, "Failed to perform inference: %s\n", get_model_error()); + } + + printf("Freeing loaded model at %p\n", (void*)loaded_model); + free_model(loaded_model); + } + + printf("Starting cleanup...\n"); + + // Clean up + printf("Freeing original model at %p\n", (void*)model); + free_model(model); + printf("Cleaning up public key at %p\n", (void*)public_key); + cleanup(public_key); + printf("Cleaning up secret key at %p\n", (void*)secret_key); + cleanup(secret_key); + printf("Cleaning up encryption...\n"); + cleanup_encryption(); + + printf("Cleanup complete.\n"); + + return 0; +} diff --git a/include/encryption.h b/include/encryption.h new file mode 100644 index 0000000..87d18df --- /dev/null +++ b/include/encryption.h @@ -0,0 +1,83 @@ +#ifndef ENCRYPTION_H +#define ENCRYPTION_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Generate a quantum-resistant key pair + * + * @param public_key Pointer to store the public key + * @param public_key_len Pointer to store the length of the public key + * @param secret_key Pointer to store the secret key + * @param secret_key_len Pointer to store the length of the secret key + * @return 0 on success, -1 on failure + */ +int generate_keypair(uint8_t **public_key, size_t *public_key_len, + uint8_t **secret_key, size_t *secret_key_len); + +/** + * Encrypt data using CRYSTALS-Kyber and AES-256-GCM + * + * @param public_key The public key + * @param public_key_len Length of the public key + * @param plaintext The data to encrypt + * @param plaintext_len Length of the plaintext + * @param ciphertext Pointer to store the encrypted data + * @param ciphertext_len Pointer to store the length of the ciphertext + * @return 0 on success, -1 on failure + */ +int encrypt(const uint8_t *public_key, size_t public_key_len, + const uint8_t *plaintext, size_t plaintext_len, + uint8_t **ciphertext, size_t *ciphertext_len); + +/** + * Decrypt data using CRYSTALS-Kyber and AES-256-GCM + * + * @param secret_key The secret key + * @param secret_key_len Length of the secret key + * @param ciphertext The data to decrypt + * @param ciphertext_len Length of the ciphertext + * @param plaintext Pointer to store the decrypted data + * @param plaintext_len Pointer to store the length of the plaintext + * @return 0 on success, -1 on failure + */ +int decrypt(const uint8_t *secret_key, size_t secret_key_len, + const uint8_t *ciphertext, size_t ciphertext_len, + uint8_t **plaintext, size_t *plaintext_len); + +/** + * Clean up and free memory + * + * @param ptr Pointer to the memory to free + */ +void cleanup(void *ptr); + +/** + * Initialize the encryption module + * This function should be called once at the start of the program + */ +void init_encryption(void); + +/** + * Cleanup the encryption module + * This function should be called once at the end of the program + */ +void cleanup_encryption(void); + +/** + * Get the last error message from the encryption module + * + * @return The last error message + */ +const char* get_error(void); + +#ifdef __cplusplus +} +#endif + +#endif /* ENCRYPTION_H */ diff --git a/include/model.h b/include/model.h new file mode 100644 index 0000000..87cbc51 --- /dev/null +++ b/include/model.h @@ -0,0 +1,107 @@ +#ifndef MODEL_H +#define MODEL_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define MAX_LAYERS 10 + +typedef struct { + float* weights; + size_t rows; + size_t cols; + int is_secure_allocated; +} Layer; + +typedef struct Model { + Layer layers[MAX_LAYERS]; + size_t num_layers; + uint8_t* public_key; + size_t public_key_len; +} Model; + +/** + * Create a new empty model + * + * @return A pointer to the new Model, or NULL on failure + */ +Model* create_model(void); + +/** + * Add a layer to the model + * + * @param model The model to add the layer to + * @param weights The weights of the layer + * @param rows The number of rows in the weight matrix + * @param cols The number of columns in the weight matrix + * @return 0 on success, -1 on failure + */ +int add_layer(Model* model, const float* weights, size_t rows, size_t cols); + +/** + * Save a model to a file + * + * @param model The model to save + * @param filename The name of the file to save the model to + * @param public_key The public key to encrypt the model + * @param public_key_len The length of the public key + * @return 0 on success, -1 on failure + */ +int save_model(const Model* model, const char* filename, const uint8_t* public_key, size_t public_key_len); + +/** + * Load an encrypted model from a file + * + * @param filename The name of the file containing the encrypted model + * @param secret_key The secret key to decrypt the model + * @param secret_key_len The length of the secret key + * @return A pointer to the loaded Model, or NULL on failure + */ +Model* load_model(const char* filename, const uint8_t* secret_key, size_t secret_key_len); + +/** + * Perform inference using the model + * + * @param model The model to use for inference + * @param input The input data + * @param input_size The size of the input data + * @param output The output data (must be pre-allocated) + * @param output_size The size of the output data + * @return 0 on success, -1 on failure + */ +int inference(const Model* model, const float* input, size_t input_size, + float* output, size_t output_size); + +/** + * Free the memory used by a model + * + * @param model The model to free + */ +void free_model(Model* model); + +/** + * Get the public key of the model + * + * @param model The model + * @param public_key Pointer to store the public key + * @param public_key_len Pointer to store the length of the public key + * @return 0 on success, -1 on failure + */ +int get_model_public_key(const Model* model, const uint8_t** public_key, size_t* public_key_len); + +/** + * Get the last error message from the model module + * + * @return The last error message + */ +const char* get_model_error(void); + +#ifdef __cplusplus +} +#endif + +#endif /* MODEL_H */ diff --git a/include/utils.h b/include/utils.h new file mode 100644 index 0000000..7cfe673 --- /dev/null +++ b/include/utils.h @@ -0,0 +1,253 @@ +#ifndef UTILS_H +#define UTILS_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Securely allocate memory and zero it out + * + * @param size The number of bytes to allocate + * @return A pointer to the allocated memory, or NULL on failure + */ +void* secure_malloc(size_t size); + +/** + * Securely free memory and set pointer to NULL + * + * @param ptr Pointer to the memory to free + */ +void secure_free(void** ptr); + +/** + * Convert a float array to a byte array + * + * @param float_array The input float array + * @param float_array_len The length of the float array + * @param byte_array Pointer to store the output byte array + * @param byte_array_len Pointer to store the length of the byte array + * @return 0 on success, -1 on failure + */ +int float_to_byte_array(const float* float_array, size_t float_array_len, + uint8_t** byte_array, size_t* byte_array_len); + +/** + * Convert a byte array to a float array + * + * @param byte_array The input byte array + * @param byte_array_len The length of the byte array + * @param float_array Pointer to store the output float array + * @param float_array_len Pointer to store the length of the float array + * @return 0 on success, -1 on failure + */ +int byte_to_float_array(const uint8_t* byte_array, size_t byte_array_len, + float** float_array, size_t* float_array_len); + +/** + * Generate a random float array + * + * @param len The length of the array to generate + * @param min The minimum value (inclusive) + * @param max The maximum value (exclusive) + * @return A pointer to the generated array, or NULL on failure + */ +float* generate_random_float_array(size_t len, float min, float max); + +/** + * Compute the dot product of two float arrays + * + * @param a The first array + * @param b The second array + * @param len The length of both arrays + * @return The dot product + */ +float dot_product(const float* a, const float* b, size_t len); + +/** + * Perform element-wise addition of two float arrays + * + * @param a The first array + * @param b The second array + * @param result The array to store the result + * @param len The length of all arrays + */ +void vector_add(const float* a, const float* b, float* result, size_t len); + +/** + * Perform element-wise subtraction of two float arrays + * + * @param a The first array + * @param b The second array + * @param result The array to store the result + * @param len The length of all arrays + */ +void vector_subtract(const float* a, const float* b, float* result, size_t len); + +/** + * Compute the L2 norm (Euclidean norm) of a float array + * + * @param a The input array + * @param len The length of the array + * @return The L2 norm + */ +float l2_norm(const float* a, size_t len); + +/** + * Normalize a float array to have unit L2 norm + * + * @param a The input array + * @param result The array to store the normalized result + * @param len The length of both arrays + */ +void normalize(const float* a, float* result, size_t len); + +/** + * Compute the softmax of a float array + * + * @param a The input array + * @param result The array to store the softmax result + * @param len The length of both arrays + */ +void softmax(const float* a, float* result, size_t len); + +/** + * Print a float array + * + * @param array The array to print + * @param len The length of the array + * @param name The name of the array (for display purposes) + */ +void print_float_array(const float* array, size_t len, const char* name); + +/** + * Save a float array to a file + * + * @param filename The name of the file to save to + * @param array The array to save + * @param len The length of the array + * @return 0 on success, -1 on failure + */ +int save_float_array(const char* filename, const float* array, size_t len); + +/** + * Load a float array from a file + * + * @param filename The name of the file to load from + * @param array Pointer to store the loaded array + * @param len Pointer to store the length of the loaded array + * @return 0 on success, -1 on failure + */ +int load_float_array(const char* filename, float** array, size_t* len); + +/** + * Initialize the random number generator + * This function should be called once at the start of the program + */ +void init_random(void); + +/** + * Compare two float values with a small epsilon to account for floating-point imprecision + * + * @param a The first float + * @param b The second float + * @param epsilon The maximum difference to consider the floats equal + * @return 1 if the floats are equal within epsilon, 0 otherwise + */ +int float_equal(float a, float b, float epsilon); + +/** + * Clip a float value to a specified range + * + * @param value The value to clip + * @param min The minimum allowed value + * @param max The maximum allowed value + * @return The clipped value + */ +float clip(float value, float min, float max); + +/** + * Compute the sigmoid of a float value + * + * @param x The input value + * @return The sigmoid of x + */ +float sigmoid(float x); + +/** + * Apply the sigmoid function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + */ +void sigmoid_array(const float* a, float* result, size_t len); + +/** + * Compute the hyperbolic tangent (tanh) of a float value + * + * @param x The input value + * @return The tanh of x + */ +float tanh_float(float x); + +/** + * Apply the tanh function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + */ +void tanh_array(const float* a, float* result, size_t len); + +/** + * Compute the ReLU (Rectified Linear Unit) of a float value + * + * @param x The input value + * @return The ReLU of x + */ +float relu(float x); + +/** + * Apply the ReLU function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + */ +void relu_array(const float* a, float* result, size_t len); + +/** + * Compute the leaky ReLU of a float value + * + * @param x The input value + * @param alpha The slope for negative values + * @return The leaky ReLU of x + */ +float leaky_relu(float x, float alpha); + +/** + * Apply the leaky ReLU function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + * @param alpha The slope for negative values + */ +void leaky_relu_array(const float* a, float* result, size_t len, float alpha); + +/** + * Get the last error message from the utils module + * + * @return The last error message + */ +const char* get_utils_error(void); + +#ifdef __cplusplus +} +#endif + +#endif /* UTILS_H */ diff --git a/src/encryption.c b/src/encryption.c new file mode 100644 index 0000000..0727cf4 --- /dev/null +++ b/src/encryption.c @@ -0,0 +1,289 @@ +#include +#include +#include +#include +#include +#include +#include +#include "../include/encryption.h" +#include "../include/utils.h" + +#define MAX_ERROR_LENGTH 256 +#define AES_256_KEY_SIZE 32 +#define GCM_IV_SIZE 12 +#define GCM_TAG_SIZE 16 + +// Global error state +static char error_message[MAX_ERROR_LENGTH] = {0}; + +typedef struct { + size_t size; + char data[]; +} secure_alloc_t; + +// Function to set error message +static void set_error(const char* message) { + strncpy(error_message, message, MAX_ERROR_LENGTH - 1); + error_message[MAX_ERROR_LENGTH - 1] = '\0'; +} + +// Function to get error message +const char* get_error() { + return error_message; +} + +int generate_keypair(uint8_t **public_key, size_t *public_key_len, + uint8_t **secret_key, size_t *secret_key_len) { + OQS_KEM *kem = NULL; + int ret = -1; + + kem = OQS_KEM_new(OQS_KEM_alg_kyber_768); + if (kem == NULL) { + set_error("Error creating KEM instance"); + return ret; + } + + *public_key = secure_malloc(kem->length_public_key); + *secret_key = secure_malloc(kem->length_secret_key); + + if (!*public_key || !*secret_key) { + set_error("Error allocating memory for keys"); + goto cleanup; + } + + if (OQS_KEM_keypair(kem, *public_key, *secret_key) != OQS_SUCCESS) { + set_error("Error generating keypair"); + goto cleanup; + } + + *public_key_len = kem->length_public_key; + *secret_key_len = kem->length_secret_key; + + ret = 0; // Success + +cleanup: + if (ret != 0) { + secure_free((void**)public_key); + secure_free((void**)secret_key); + } + OQS_KEM_free(kem); + return ret; +} + +int encrypt(const uint8_t *public_key, size_t public_key_len, + const uint8_t *plaintext, size_t plaintext_len, + uint8_t **ciphertext, size_t *ciphertext_len) { + OQS_KEM *kem = NULL; + EVP_CIPHER_CTX *ctx = NULL; + uint8_t *kem_ciphertext = NULL; + uint8_t *shared_secret = NULL; + uint8_t *aes_ciphertext = NULL; + uint8_t iv[GCM_IV_SIZE]; + uint8_t tag[GCM_TAG_SIZE]; + int len, aes_ciphertext_len; + int ret = -1; + + kem = OQS_KEM_new(OQS_KEM_alg_kyber_768); + if (kem == NULL) { + set_error("Error creating KEM instance"); + return ret; + } + + if (public_key_len != kem->length_public_key) { + set_error("Invalid public key length"); + goto cleanup; + } + + kem_ciphertext = secure_malloc(kem->length_ciphertext); + shared_secret = secure_malloc(kem->length_shared_secret); + aes_ciphertext = secure_malloc(plaintext_len + EVP_MAX_BLOCK_LENGTH); + + if (!kem_ciphertext || !shared_secret || !aes_ciphertext) { + set_error("Error allocating memory"); + goto cleanup; + } + + if (OQS_KEM_encaps(kem, kem_ciphertext, shared_secret, public_key) != OQS_SUCCESS) { + set_error("Error in KEM encapsulation"); + goto cleanup; + } + + // Generate a random IV + if (RAND_bytes(iv, GCM_IV_SIZE) != 1) { + set_error("Error generating random IV"); + goto cleanup; + } + + // Create and initialise the context + if (!(ctx = EVP_CIPHER_CTX_new())) { + set_error("Error creating cipher context"); + goto cleanup; + } + + // Initialise the encryption operation + if (EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, shared_secret, iv) != 1) { + set_error("Error initializing encryption"); + goto cleanup; + } + + // Encrypt plaintext + if (EVP_EncryptUpdate(ctx, aes_ciphertext, &len, plaintext, plaintext_len) != 1) { + set_error("Error in encryption update"); + goto cleanup; + } + aes_ciphertext_len = len; + + // Finalize encryption + if (EVP_EncryptFinal_ex(ctx, aes_ciphertext + len, &len) != 1) { + set_error("Error finalizing encryption"); + goto cleanup; + } + aes_ciphertext_len += len; + + // Get the tag + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, GCM_TAG_SIZE, tag) != 1) { + set_error("Error getting tag"); + goto cleanup; + } + + // Assemble final ciphertext: KEM ciphertext + IV + AES ciphertext + tag + *ciphertext_len = kem->length_ciphertext + GCM_IV_SIZE + aes_ciphertext_len + GCM_TAG_SIZE; + *ciphertext = secure_malloc(*ciphertext_len); + if (!*ciphertext) { + set_error("Error allocating memory for final ciphertext"); + goto cleanup; + } + + memcpy(*ciphertext, kem_ciphertext, kem->length_ciphertext); + memcpy(*ciphertext + kem->length_ciphertext, iv, GCM_IV_SIZE); + memcpy(*ciphertext + kem->length_ciphertext + GCM_IV_SIZE, aes_ciphertext, aes_ciphertext_len); + memcpy(*ciphertext + kem->length_ciphertext + GCM_IV_SIZE + aes_ciphertext_len, tag, GCM_TAG_SIZE); + + ret = 0; // Success + +cleanup: + if (ctx) EVP_CIPHER_CTX_free(ctx); + secure_free((void**)&kem_ciphertext); + secure_free((void**)&shared_secret); + secure_free((void**)&aes_ciphertext); + OQS_KEM_free(kem); + return ret; +} + +int decrypt(const uint8_t *secret_key, size_t secret_key_len, + const uint8_t *ciphertext, size_t ciphertext_len, + uint8_t **plaintext, size_t *plaintext_len) { + OQS_KEM *kem = NULL; + EVP_CIPHER_CTX *ctx = NULL; + uint8_t *shared_secret = NULL; + uint8_t *aes_ciphertext = NULL; + uint8_t iv[GCM_IV_SIZE]; + uint8_t tag[GCM_TAG_SIZE]; + int len; + int ret = -1; + + kem = OQS_KEM_new(OQS_KEM_alg_kyber_768); + if (kem == NULL) { + set_error("Error creating KEM instance"); + return ret; + } + + if (secret_key_len != kem->length_secret_key || + ciphertext_len <= kem->length_ciphertext + GCM_IV_SIZE + GCM_TAG_SIZE) { + set_error("Invalid key or ciphertext length"); + goto cleanup; + } + + shared_secret = secure_malloc(kem->length_shared_secret); + if (!shared_secret) { + set_error("Error allocating memory"); + goto cleanup; + } + + // Decapsulate to get the shared secret + if (OQS_KEM_decaps(kem, shared_secret, ciphertext, secret_key) != OQS_SUCCESS) { + set_error("Error in KEM decapsulation"); + goto cleanup; + } + + // Extract IV and tag from ciphertext + memcpy(iv, ciphertext + kem->length_ciphertext, GCM_IV_SIZE); + memcpy(tag, ciphertext + ciphertext_len - GCM_TAG_SIZE, GCM_TAG_SIZE); + + // Create and initialise the context + if (!(ctx = EVP_CIPHER_CTX_new())) { + set_error("Error creating cipher context"); + goto cleanup; + } + + // Initialise the decryption operation + if (EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, shared_secret, iv) != 1) { + set_error("Error initializing decryption"); + goto cleanup; + } + + // Set expected tag value + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_SIZE, (void*)tag) != 1) { + set_error("Error setting tag"); + goto cleanup; + } + + // Allocate memory for plaintext + size_t aes_ciphertext_len = ciphertext_len - kem->length_ciphertext - GCM_IV_SIZE - GCM_TAG_SIZE; + *plaintext = secure_malloc(aes_ciphertext_len); + if (!*plaintext) { + set_error("Error allocating memory for plaintext"); + goto cleanup; + } + + // Decrypt ciphertext + if (EVP_DecryptUpdate(ctx, *plaintext, &len, + ciphertext + kem->length_ciphertext + GCM_IV_SIZE, + aes_ciphertext_len) != 1) { + set_error("Error in decryption update"); + goto cleanup; + } + *plaintext_len = len; + + // Finalize decryption + if (EVP_DecryptFinal_ex(ctx, *plaintext + len, &len) != 1) { + set_error("Error finalizing decryption"); + goto cleanup; + } + *plaintext_len += len; + + ret = 0; // Success + +cleanup: + if (ctx) EVP_CIPHER_CTX_free(ctx); + secure_free((void**)&shared_secret); + OQS_KEM_free(kem); + if (ret != 0 && *plaintext) { + secure_free((void**)plaintext); + } + return ret; +} + +void cleanup(void *ptr) { + if (ptr) { + // Check if the pointer was allocated by our secure_malloc + secure_alloc_t* alloc = (secure_alloc_t*)((char*)ptr - offsetof(secure_alloc_t, data)); + if (alloc->size > 0) { + // This was allocated by our secure_malloc, so use secure_free + secure_free(&ptr); + } else { + // This was likely allocated by the OQS library, so use regular free + free(ptr); + } + } +} + +void init_encryption() { + OpenSSL_add_all_algorithms(); + ERR_load_crypto_strings(); +} + +void cleanup_encryption() { + EVP_cleanup(); + ERR_free_strings(); +} diff --git a/src/main.c b/src/main.c new file mode 100644 index 0000000..0c0c36e --- /dev/null +++ b/src/main.c @@ -0,0 +1,209 @@ +#include +#include +#include +#include "../include/encryption.h" +#include "../include/model.h" +#include "../include/utils.h" + +#define INPUT_SIZE 784 // MNIST-like input size +#define OUTPUT_SIZE 10 // 10 classes for classification + +void print_usage(const char* program_name) { + printf("Usage: %s \n", program_name); +} + +int main(int argc, char* argv[]) { + if (argc != 3) { + print_usage(argv[0]); + return 1; + } + + const char* model_file = argv[1]; + const char* secret_key_file = argv[2]; + + // Initialize the encryption and utility modules + init_encryption(); + init_random(); + + // Load the secret key + uint8_t* secret_key; + size_t secret_key_len; + FILE* key_file = fopen(secret_key_file, "rb"); + if (!key_file) { + fprintf(stderr, "Error: Unable to open secret key file.\n"); + return 1; + } + fseek(key_file, 0, SEEK_END); + secret_key_len = ftell(key_file); + fseek(key_file, 0, SEEK_SET); + secret_key = malloc(secret_key_len); + if (!secret_key || fread(secret_key, 1, secret_key_len, key_file) != secret_key_len) { + fprintf(stderr, "Error: Unable to read secret key.\n"); + fclose(key_file); + return 1; + } + fclose(key_file); + + // Load the encrypted model + Model* model = load_model(model_file, secret_key, secret_key_len); + if (!model) { + fprintf(stderr, "Error: %s\n", get_model_error()); + free(secret_key); + return 1; + } + + // Get the model's public key + const uint8_t* public_key; + size_t public_key_len; + if (get_model_public_key(model, &public_key, &public_key_len) != 0) { + fprintf(stderr, "Error: Unable to get model's public key.\n"); + free_model(model); + free(secret_key); + return 1; + } + + // Generate a random input (simulating an image) + float* input = generate_random_float_array(INPUT_SIZE, 0.0f, 1.0f); + if (!input) { + fprintf(stderr, "Error: %s\n", get_utils_error()); + free_model(model); + free(secret_key); + return 1; + } + + printf("Generated random input:\n"); + print_float_array(input, 10, "First 10 elements"); // Print first 10 elements for conciseness + + // Encrypt the input + uint8_t* input_bytes; + size_t input_bytes_len; + if (float_to_byte_array(input, INPUT_SIZE, &input_bytes, &input_bytes_len) != 0) { + fprintf(stderr, "Error: %s\n", get_utils_error()); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + uint8_t* encrypted_input; + size_t encrypted_input_len; + if (encrypt(public_key, public_key_len, input_bytes, input_bytes_len, &encrypted_input, &encrypted_input_len) != 0) { + fprintf(stderr, "Error: %s\n", get_error()); + free(input_bytes); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + free(input_bytes); + + // Decrypt the input (simulating what would happen on the server) + uint8_t* decrypted_input; + size_t decrypted_input_len; + if (decrypt(secret_key, secret_key_len, encrypted_input, encrypted_input_len, &decrypted_input, &decrypted_input_len) != 0) { + fprintf(stderr, "Error: %s\n", get_error()); + free(encrypted_input); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + free(encrypted_input); + + // Convert decrypted input back to float array + float* decrypted_input_float; + size_t decrypted_input_float_len; + if (byte_to_float_array(decrypted_input, decrypted_input_len, &decrypted_input_float, &decrypted_input_float_len) != 0) { + fprintf(stderr, "Error: %s\n", get_utils_error()); + free(decrypted_input); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + free(decrypted_input); + + // Perform inference + float* output = malloc(OUTPUT_SIZE * sizeof(float)); + if (!output) { + fprintf(stderr, "Error: Unable to allocate memory for output.\n"); + free(decrypted_input_float); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + if (inference(model, decrypted_input_float, INPUT_SIZE, output, OUTPUT_SIZE) != 0) { + fprintf(stderr, "Error: %s\n", get_model_error()); + free(output); + free(decrypted_input_float); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + free(decrypted_input_float); + + // Print the output + printf("Model output:\n"); + print_float_array(output, OUTPUT_SIZE, "Probabilities"); + + // Find the highest probability class + float max_prob = output[0]; + int max_class = 0; + for (int i = 1; i < OUTPUT_SIZE; i++) { + if (output[i] > max_prob) { + max_prob = output[i]; + max_class = i; + } + } + + printf("Predicted class: %d (probability: %.4f)\n", max_class, max_prob); + + // Encrypt the output (simulating sending the result back to the client) + uint8_t* output_bytes; + size_t output_bytes_len; + if (float_to_byte_array(output, OUTPUT_SIZE, &output_bytes, &output_bytes_len) != 0) { + fprintf(stderr, "Error: %s\n", get_utils_error()); + free(output); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + uint8_t* encrypted_output; + size_t encrypted_output_len; + if (encrypt(public_key, public_key_len, output_bytes, output_bytes_len, &encrypted_output, &encrypted_output_len) != 0) { + fprintf(stderr, "Error: %s\n", get_error()); + free(output_bytes); + free(output); + free(input); + free_model(model); + free(secret_key); + return 1; + } + + free(output_bytes); + + printf("Encrypted output length: %zu bytes\n", encrypted_output_len); + + // Clean up + free(encrypted_output); + free(output); + free(input); + free_model(model); + free(secret_key); + + // Clean up the encryption and utility modules + cleanup_encryption(); + + printf("Quantum-resistant encrypted inference completed successfully.\n"); + + return 0; +} diff --git a/src/model.c b/src/model.c new file mode 100644 index 0000000..81ddfcd --- /dev/null +++ b/src/model.c @@ -0,0 +1,298 @@ +#include +#include +#include +#include "../include/model.h" +#include "../include/encryption.h" +#include "../include/utils.h" + +#define MAX_ERROR_LENGTH 256 + +static char error_message[MAX_ERROR_LENGTH] = {0}; + +static void set_error(const char* message) { + strncpy(error_message, message, MAX_ERROR_LENGTH - 1); + error_message[MAX_ERROR_LENGTH - 1] = '\0'; +} + +const char* get_model_error(void) { + return error_message; +} + +Model* create_model(void) { + Model* model = secure_malloc(sizeof(Model)); + if (!model) { + set_error("Failed to allocate memory for model"); + return NULL; + } + memset(model, 0, sizeof(Model)); + return model; +} + +int add_layer(Model* model, const float* weights, size_t rows, size_t cols) { + if (!model) { + set_error("Invalid model pointer"); + return -1; + } + if (model->num_layers >= MAX_LAYERS) { + set_error("Maximum number of layers reached"); + return -1; + } + + Layer* layer = &model->layers[model->num_layers]; + layer->weights = secure_malloc(rows * cols * sizeof(float)); + if (!layer->weights) { + set_error("Failed to allocate memory for layer weights"); + return -1; + } + + memcpy(layer->weights, weights, rows * cols * sizeof(float)); + layer->rows = rows; + layer->cols = cols; + layer->is_secure_allocated = 1; + model->num_layers++; + + return 0; +} + +int save_model(const Model* model, const char* filename, const uint8_t* public_key, size_t public_key_len) { + if (!model || !filename || !public_key) { + set_error("Invalid parameters for save_model"); + return -1; + } + + FILE* file = fopen(filename, "wb"); + if (!file) { + set_error("Failed to open file for writing"); + return -1; + } + + // Write number of layers + fwrite(&model->num_layers, sizeof(size_t), 1, file); + + // Write each layer + for (size_t i = 0; i < model->num_layers; i++) { + const Layer* layer = &model->layers[i]; + fwrite(&layer->rows, sizeof(size_t), 1, file); + fwrite(&layer->cols, sizeof(size_t), 1, file); + + // Encrypt weights + uint8_t* encrypted_weights; + size_t encrypted_weights_len; + if (encrypt(public_key, public_key_len, (uint8_t*)layer->weights, + layer->rows * layer->cols * sizeof(float), + &encrypted_weights, &encrypted_weights_len) != 0) { + set_error("Failed to encrypt layer weights"); + fclose(file); + return -1; + } + + // Write encrypted weights + fwrite(&encrypted_weights_len, sizeof(size_t), 1, file); + fwrite(encrypted_weights, 1, encrypted_weights_len, file); + secure_free((void**)&encrypted_weights); + } + + // Write public key + fwrite(&public_key_len, sizeof(size_t), 1, file); + fwrite(public_key, 1, public_key_len, file); + + fclose(file); + return 0; +} + +Model* load_model(const char* filename, const uint8_t* secret_key, size_t secret_key_len) { + if (!filename || !secret_key) { + set_error("Invalid parameters for load_model"); + return NULL; + } + + FILE* file = fopen(filename, "rb"); + if (!file) { + set_error("Failed to open file for reading"); + return NULL; + } + + Model* model = create_model(); + if (!model) { + fclose(file); + return NULL; + } + + // Read number of layers + if (fread(&model->num_layers, sizeof(size_t), 1, file) != 1) { + set_error("Failed to read number of layers"); + free_model(model); + fclose(file); + return NULL; + } + + // Read each layer + for (size_t i = 0; i < model->num_layers; i++) { + Layer* layer = &model->layers[i]; + if (fread(&layer->rows, sizeof(size_t), 1, file) != 1 || + fread(&layer->cols, sizeof(size_t), 1, file) != 1) { + set_error("Failed to read layer dimensions"); + free_model(model); + fclose(file); + return NULL; + } + + // Read encrypted weights + size_t encrypted_weights_len; + if (fread(&encrypted_weights_len, sizeof(size_t), 1, file) != 1) { + set_error("Failed to read encrypted weights length"); + free_model(model); + fclose(file); + return NULL; + } + + uint8_t* encrypted_weights = secure_malloc(encrypted_weights_len); + if (!encrypted_weights) { + set_error("Failed to allocate memory for encrypted weights"); + free_model(model); + fclose(file); + return NULL; + } + + if (fread(encrypted_weights, 1, encrypted_weights_len, file) != encrypted_weights_len) { + set_error("Failed to read encrypted weights"); + secure_free((void**)&encrypted_weights); + free_model(model); + fclose(file); + return NULL; + } + + // Decrypt weights + uint8_t* decrypted_weights; + size_t decrypted_weights_len; + if (decrypt(secret_key, secret_key_len, encrypted_weights, encrypted_weights_len, + &decrypted_weights, &decrypted_weights_len) != 0) { + set_error("Failed to decrypt layer weights"); + secure_free((void**)&encrypted_weights); + free_model(model); + fclose(file); + return NULL; + } + + secure_free((void**)&encrypted_weights); + + layer->weights = (float*)decrypted_weights; + layer->is_secure_allocated = 1; // Mark as secure allocated + if (decrypted_weights_len != layer->rows * layer->cols * sizeof(float)) { + set_error("Decrypted weights size mismatch"); + free_model(model); + fclose(file); + return NULL; + } + } + + // Read public key + if (fread(&model->public_key_len, sizeof(size_t), 1, file) != 1) { + set_error("Failed to read public key length"); + free_model(model); + fclose(file); + return NULL; + } + + model->public_key = secure_malloc(model->public_key_len); + if (!model->public_key) { + set_error("Failed to allocate memory for public key"); + free_model(model); + fclose(file); + return NULL; + } + + if (fread(model->public_key, 1, model->public_key_len, file) != model->public_key_len) { + set_error("Failed to read public key"); + free_model(model); + fclose(file); + return NULL; + } + + fclose(file); + return model; +} + +int inference(const Model* model, const float* input, size_t input_size, + float* output, size_t output_size) { + if (!model || !input || !output) { + set_error("Invalid parameters for inference"); + return -1; + } + + if (model->num_layers == 0) { + set_error("Model has no layers"); + return -1; + } + + if (input_size != model->layers[0].cols) { + set_error("Input size mismatch"); + return -1; + } + + if (output_size != model->layers[model->num_layers - 1].rows) { + set_error("Output size mismatch"); + return -1; + } + + float* temp_input = (float*)secure_malloc(input_size * sizeof(float)); + float* temp_output = (float*)secure_malloc(output_size * sizeof(float)); + if (!temp_input || !temp_output) { + set_error("Failed to allocate memory for temporary buffers"); + secure_free((void**)&temp_input); + secure_free((void**)&temp_output); + return -1; + } + + memcpy(temp_input, input, input_size * sizeof(float)); + + for (size_t i = 0; i < model->num_layers; i++) { + const Layer* layer = &model->layers[i]; + for (size_t j = 0; j < layer->rows; j++) { + float sum = 0; + for (size_t k = 0; k < layer->cols; k++) { + sum += layer->weights[j * layer->cols + k] * temp_input[k]; + } + temp_output[j] = (sum > 0) ? sum : 0; // ReLU activation + } + memcpy(temp_input, temp_output, layer->rows * sizeof(float)); + } + + memcpy(output, temp_output, output_size * sizeof(float)); + + secure_free((void**)&temp_input); + secure_free((void**)&temp_output); + + return 0; +} + +void free_model(Model* model) { + if (model) { + printf("Freeing model at %p\n", (void*)model); + for (size_t i = 0; i < model->num_layers; i++) { + printf("Freeing layer %zu weights at %p\n", i, (void*)model->layers[i].weights); + if (model->layers[i].weights) { + secure_free((void**)&model->layers[i].weights); + } + } + if (model->public_key) { + printf("Freeing model public key at %p\n", (void*)model->public_key); + secure_free((void**)&model->public_key); + } + secure_free((void**)&model); + } else { + printf("Attempted to free NULL model.\n"); + } +} + +int get_model_public_key(const Model* model, const uint8_t** public_key, size_t* public_key_len) { + if (!model || !public_key || !public_key_len) { + set_error("Invalid parameters for get_model_public_key"); + return -1; + } + + *public_key = model->public_key; + *public_key_len = model->public_key_len; + + return 0; +} diff --git a/src/utils.c b/src/utils.c new file mode 100644 index 0000000..7091a9c --- /dev/null +++ b/src/utils.c @@ -0,0 +1,450 @@ +#include +#include +#include +#include +#include +#include "../include/utils.h" + +#define MAX_ERROR_LENGTH 256 + +// Global error state +static char error_message[MAX_ERROR_LENGTH] = {0}; + +// Function to set error message +void set_utils_error(const char* message) { + strncpy(error_message, message, MAX_ERROR_LENGTH - 1); + error_message[MAX_ERROR_LENGTH - 1] = '\0'; +} + +// Function to get error message +const char* get_utils_error() { + return error_message; +} + +typedef struct { + size_t size; + char data[]; +} secure_alloc_t; + +/** + * Securely allocate memory and zero it out + * + * @param size The number of bytes to allocate + * @return A pointer to the allocated memory, or NULL on failure + */ +void* secure_malloc(size_t size) { + secure_alloc_t* alloc = malloc(sizeof(secure_alloc_t) + size); + if (alloc) { + alloc->size = size; + memset(alloc->data, 0, size); + printf("secure_malloc: Allocated %zu bytes at %p (returned %p)\n", size, (void*)alloc, (void*)alloc->data); + return alloc->data; + } + set_utils_error("Failed to allocate memory"); + return NULL; +} + +/** + * Securely free memory and set pointer to NULL + * + * @param ptr Pointer to the memory to free + */ +void secure_free(void** ptr) { + if (ptr != NULL && *ptr != NULL) { + secure_alloc_t* alloc = (secure_alloc_t*)((char*)*ptr - offsetof(secure_alloc_t, data)); + printf("secure_free: Freeing %zu bytes at %p (original pointer %p)\n", alloc->size, (void*)alloc, (void*)*ptr); + memset(alloc->data, 0, alloc->size); + free(alloc); + *ptr = NULL; + } else { + printf("secure_free: Nothing to free (ptr is NULL or *ptr is NULL)\n"); + } +} + +/** + * Convert a float array to a byte array + * + * @param float_array The input float array + * @param float_array_len The length of the float array + * @param byte_array Pointer to store the output byte array + * @param byte_array_len Pointer to store the length of the byte array + * @return 0 on success, -1 on failure + */ +int float_to_byte_array(const float* float_array, size_t float_array_len, + uint8_t** byte_array, size_t* byte_array_len) { + *byte_array_len = float_array_len * sizeof(float); + *byte_array = secure_malloc(*byte_array_len); + if (!*byte_array) { + set_utils_error("Failed to allocate memory for byte array"); + return -1; + } + memcpy(*byte_array, float_array, *byte_array_len); + return 0; +} + +/** + * Convert a byte array to a float array + * + * @param byte_array The input byte array + * @param byte_array_len The length of the byte array + * @param float_array Pointer to store the output float array + * @param float_array_len Pointer to store the length of the float array + * @return 0 on success, -1 on failure + */ +int byte_to_float_array(const uint8_t* byte_array, size_t byte_array_len, + float** float_array, size_t* float_array_len) { + if (byte_array_len % sizeof(float) != 0) { + set_utils_error("Byte array length is not a multiple of sizeof(float)"); + return -1; + } + *float_array_len = byte_array_len / sizeof(float); + *float_array = secure_malloc(byte_array_len); + if (!*float_array) { + set_utils_error("Failed to allocate memory for float array"); + return -1; + } + memcpy(*float_array, byte_array, byte_array_len); + return 0; +} + +/** + * Generate a random float array + * + * @param len The length of the array to generate + * @param min The minimum value (inclusive) + * @param max The maximum value (exclusive) + * @return A pointer to the generated array, or NULL on failure + */ +float* generate_random_float_array(size_t len, float min, float max) { + float* array = (float*)secure_malloc(len * sizeof(float)); + if (!array) { + set_utils_error("Failed to allocate memory for random float array"); + return NULL; + } + + for (size_t i = 0; i < len; i++) { + array[i] = min + (max - min) * ((float)rand() / RAND_MAX); + } + + return array; +} + +/** + * Compute the dot product of two float arrays + * + * @param a The first array + * @param b The second array + * @param len The length of both arrays + * @return The dot product + */ +float dot_product(const float* a, const float* b, size_t len) { + float result = 0.0f; + for (size_t i = 0; i < len; i++) { + result += a[i] * b[i]; + } + return result; +} + +/** + * Perform element-wise addition of two float arrays + * + * @param a The first array + * @param b The second array + * @param result The array to store the result + * @param len The length of all arrays + */ +void vector_add(const float* a, const float* b, float* result, size_t len) { + for (size_t i = 0; i < len; i++) { + result[i] = a[i] + b[i]; + } +} + +/** + * Perform element-wise subtraction of two float arrays + * + * @param a The first array + * @param b The second array + * @param result The array to store the result + * @param len The length of all arrays + */ +void vector_subtract(const float* a, const float* b, float* result, size_t len) { + for (size_t i = 0; i < len; i++) { + result[i] = a[i] - b[i]; + } +} + +/** + * Compute the L2 norm (Euclidean norm) of a float array + * + * @param a The input array + * @param len The length of the array + * @return The L2 norm + */ +float l2_norm(const float* a, size_t len) { + float sum = 0.0f; + for (size_t i = 0; i < len; i++) { + sum += a[i] * a[i]; + } + return sqrt(sum); +} + +/** + * Normalize a float array to have unit L2 norm + * + * @param a The input array + * @param result The array to store the normalized result + * @param len The length of both arrays + */ +void normalize(const float* a, float* result, size_t len) { + float norm = l2_norm(a, len); + if (norm == 0) { + set_utils_error("Cannot normalize zero vector"); + return; + } + for (size_t i = 0; i < len; i++) { + result[i] = a[i] / norm; + } +} + +/** + * Compute the softmax of a float array + * + * @param a The input array + * @param result The array to store the softmax result + * @param len The length of both arrays + */ +void softmax(const float* a, float* result, size_t len) { + float max = a[0]; + for (size_t i = 1; i < len; i++) { + if (a[i] > max) { + max = a[i]; + } + } + + float sum = 0.0f; + for (size_t i = 0; i < len; i++) { + result[i] = exp(a[i] - max); + sum += result[i]; + } + + for (size_t i = 0; i < len; i++) { + result[i] /= sum; + } +} + +/** + * Print a float array + * + * @param array The array to print + * @param len The length of the array + * @param name The name of the array (for display purposes) + */ +void print_float_array(const float* array, size_t len, const char* name) { + printf("%s: [", name); + for (size_t i = 0; i < len; i++) { + printf("%f", array[i]); + if (i < len - 1) { + printf(", "); + } + } + printf("]\n"); +} + +/** + * Save a float array to a file + * + * @param filename The name of the file to save to + * @param array The array to save + * @param len The length of the array + * @return 0 on success, -1 on failure + */ +int save_float_array(const char* filename, const float* array, size_t len) { + FILE* file = fopen(filename, "wb"); + if (!file) { + set_utils_error("Failed to open file for writing"); + return -1; + } + + size_t written = fwrite(array, sizeof(float), len, file); + if (written != len) { + set_utils_error("Failed to write entire array to file"); + fclose(file); + return -1; + } + + fclose(file); + return 0; +} + +/** + * Load a float array from a file + * + * @param filename The name of the file to load from + * @param array Pointer to store the loaded array + * @param len Pointer to store the length of the loaded array + * @return 0 on success, -1 on failure + */ +int load_float_array(const char* filename, float** array, size_t* len) { + FILE* file = fopen(filename, "rb"); + if (!file) { + set_utils_error("Failed to open file for reading"); + return -1; + } + + fseek(file, 0, SEEK_END); + long file_size = ftell(file); + fseek(file, 0, SEEK_SET); + + if (file_size % sizeof(float) != 0) { + set_utils_error("File size is not a multiple of sizeof(float)"); + fclose(file); + return -1; + } + + *len = file_size / sizeof(float); + *array = secure_malloc(file_size); + if (!*array) { + set_utils_error("Failed to allocate memory for loaded array"); + fclose(file); + return -1; + } + + size_t read = fread(*array, sizeof(float), *len, file); + if (read != *len) { + set_utils_error("Failed to read entire array from file"); + secure_free((void**)array); + fclose(file); + return -1; + } + + fclose(file); + return 0; +} + +/** + * Initialize the random number generator + * This function should be called once at the start of the program + */ +void init_random() { + srand(time(NULL)); +} + +/** + * Compare two float values with a small epsilon to account for floating-point imprecision + * + * @param a The first float + * @param b The second float + * @param epsilon The maximum difference to consider the floats equal + * @return 1 if the floats are equal within epsilon, 0 otherwise + */ +int float_equal(float a, float b, float epsilon) { + return fabs(a - b) < epsilon; +} + +/** + * Clip a float value to a specified range + * + * @param value The value to clip + * @param min The minimum allowed value + * @param max The maximum allowed value + * @return The clipped value + */ +float clip(float value, float min, float max) { + if (value < min) return min; + if (value > max) return max; + return value; +} + +/** + * Compute the sigmoid of a float value + * + * @param x The input value + * @return The sigmoid of x + */ +float sigmoid(float x) { + return 1.0f / (1.0f + exp(-x)); +} + +/** + * Apply the sigmoid function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + */ +void sigmoid_array(const float* a, float* result, size_t len) { + for (size_t i = 0; i < len; i++) { + result[i] = sigmoid(a[i]); + } +} + +/** + * Compute the hyperbolic tangent (tanh) of a float value + * + * @param x The input value + * @return The tanh of x + */ +float tanh_float(float x) { + return tanh(x); +} + +/** + * Apply the tanh function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + */ +void tanh_array(const float* a, float* result, size_t len) { + for (size_t i = 0; i < len; i++) { + result[i] = tanh_float(a[i]); + } +} + +/** + * Compute the ReLU (Rectified Linear Unit) of a float value + * + * @param x The input value + * @return The ReLU of x + */ +float relu(float x) { + return x > 0 ? x : 0; +} + +/** + * Apply the ReLU function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + */ +void relu_array(const float* a, float* result, size_t len) { + for (size_t i = 0; i < len; i++) { + result[i] = relu(a[i]); + } +} + +/** + * Compute the leaky ReLU of a float value + * + * @param x The input value + * @param alpha The slope for negative values + * @return The leaky ReLU of x + */ +float leaky_relu(float x, float alpha) { + return x > 0 ? x : alpha * x; +} + +/** + * Apply the leaky ReLU function element-wise to a float array + * + * @param a The input array + * @param result The array to store the result + * @param len The length of both arrays + * @param alpha The slope for negative values + */ +void leaky_relu_array(const float* a, float* result, size_t len, float alpha) { + for (size_t i = 0; i < len; i++) { + result[i] = leaky_relu(a[i], alpha); + } +} diff --git a/tests/test_all.c b/tests/test_all.c new file mode 100644 index 0000000..00384ff --- /dev/null +++ b/tests/test_all.c @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include +#include "../include/encryption.h" +#include "../include/model.h" +#include "../include/utils.h" + +#define TEST_MESSAGE "Hello, LLM and Quantum World!" +#define EPSILON 1e-6 +#define TEST_MODEL_FILE "test_model.bin" + +// Helper function to compare float arrays +int compare_float_arrays(const float* a, const float* b, size_t len, float epsilon) { + for (size_t i = 0; i < len; i++) { + if (fabs(a[i] - b[i]) > epsilon) { + return 0; + } + } + return 1; +} + +// Encryption Tests +void test_key_generation() { + printf("Testing key generation...\n"); + + uint8_t *public_key, *secret_key; + size_t public_key_len, secret_key_len; + + int result = generate_keypair(&public_key, &public_key_len, &secret_key, &secret_key_len); + assert(result == 0); + assert(public_key != NULL); + assert(secret_key != NULL); + assert(public_key_len > 0); + assert(secret_key_len > 0); + + printf("Public key length: %zu\n", public_key_len); + printf("Secret key length: %zu\n", secret_key_len); + + cleanup(public_key); + cleanup(secret_key); + + printf("Key generation test passed.\n\n"); +} + +void test_encryption_decryption() { + printf("Testing encryption and decryption...\n"); + + uint8_t *public_key, *secret_key; + size_t public_key_len, secret_key_len; + + int result = generate_keypair(&public_key, &public_key_len, &secret_key, &secret_key_len); + assert(result == 0); + + const uint8_t *plaintext = (const uint8_t *)TEST_MESSAGE; + size_t plaintext_len = strlen(TEST_MESSAGE); + + uint8_t *ciphertext; + size_t ciphertext_len; + + result = encrypt(public_key, public_key_len, plaintext, plaintext_len, &ciphertext, &ciphertext_len); + assert(result == 0); + assert(ciphertext != NULL); + assert(ciphertext_len > plaintext_len); + + uint8_t *decrypted; + size_t decrypted_len; + + result = decrypt(secret_key, secret_key_len, ciphertext, ciphertext_len, &decrypted, &decrypted_len); + assert(result == 0); + assert(decrypted != NULL); + assert(decrypted_len == plaintext_len); + assert(memcmp(plaintext, decrypted, plaintext_len) == 0); + + printf("Original message: %s\n", TEST_MESSAGE); + printf("Decrypted message: %s\n", decrypted); + + cleanup(public_key); + cleanup(secret_key); + cleanup(ciphertext); + cleanup(decrypted); + + printf("Encryption and decryption test passed.\n\n"); +} + +// Model Tests +void test_create_model() { + printf("Testing model creation...\n"); + + Model* model = create_model(); + assert(model != NULL); + + free_model(model); + + printf("Model creation test passed.\n\n"); +} + +void test_add_layer() { + printf("Testing add layer...\n"); + + Model* model = create_model(); + assert(model != NULL); + + float weights[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int result = add_layer(model, weights, 2, 3); + assert(result == 0); + + free_model(model); + + printf("Add layer test passed.\n\n"); +} + +void test_save_load_model() { + printf("Testing save and load model...\n"); + + // Create and populate a model + Model* model = create_model(); + assert(model != NULL); + + float weights1[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + float weights2[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + add_layer(model, weights1, 2, 3); + add_layer(model, weights2, 5, 1); + + // Generate a key pair for encryption + uint8_t *public_key, *secret_key; + size_t public_key_len, secret_key_len; + int result = generate_keypair(&public_key, &public_key_len, &secret_key, &secret_key_len); + assert(result == 0); + + // Save the model + result = save_model(model, TEST_MODEL_FILE, public_key, public_key_len); + assert(result == 0); + + // Free the original model + free_model(model); + + // Load the model + Model* loaded_model = load_model(TEST_MODEL_FILE, secret_key, secret_key_len); + assert(loaded_model != NULL); + + // Clean up + free_model(loaded_model); + cleanup(public_key); + cleanup(secret_key); + remove(TEST_MODEL_FILE); + + printf("Save and load model test passed.\n\n"); +} + +void test_inference() { + printf("Testing model inference...\n"); + + // Create and populate a model + Model* model = create_model(); + assert(model != NULL); + + // Simple model: 2 layers, 3 inputs, 2 hidden neurons, 1 output + float weights1[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}; + float weights2[] = {0.7f, 0.8f}; + add_layer(model, weights1, 2, 3); + add_layer(model, weights2, 1, 2); + + // Prepare input + float input[] = {1.0f, 2.0f, 3.0f}; + float output[1]; + + // Perform inference + int result = inference(model, input, 3, output, 1); + assert(result == 0); + + // Expected output: ReLU(0.7 * ReLU(0.1*1 + 0.2*2 + 0.3*3) + 0.8 * ReLU(0.4*1 + 0.5*2 + 0.6*3)) + // = ReLU(0.7 * ReLU(1.4) + 0.8 * ReLU(3.2)) + // = ReLU(0.7 * 1.4 + 0.8 * 3.2) + // = ReLU(0.98 + 2.56) + // = 3.54 + float expected_output = 3.54f; + assert(fabs(output[0] - expected_output) < EPSILON); + + // Clean up + free_model(model); + + printf("Model inference test passed.\n\n"); +} + +int main() { + printf("Starting all tests...\n\n"); + + init_encryption(); + init_random(); + + // Encryption tests + test_key_generation(); + test_encryption_decryption(); + + // Model tests + test_create_model(); + test_add_layer(); + test_save_load_model(); + test_inference(); + + cleanup_encryption(); + + printf("All tests passed successfully!\n"); + + return 0; +}