From 9922332834d0620fb86aef8ca723b15f292e17d7 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Thu, 26 Dec 2024 14:09:59 +0000 Subject: [PATCH] raw-compute testing is now working --- .../src/api/execution/buffered_compute.rs | 0 modules/c-wrapper/src/api/execution/mod.rs | 1 + .../execution/test_raw_compute.py} | 25 +------- .../c-wrapper/tests/api/storage/test_meta.py | 60 ++++++++++++++++++- .../tests/test_utils/return_structs.py | 12 +++- 5 files changed, 73 insertions(+), 25 deletions(-) create mode 100644 modules/c-wrapper/src/api/execution/buffered_compute.rs rename modules/c-wrapper/tests/{test_execution.py => api/execution/test_raw_compute.py} (66%) diff --git a/modules/c-wrapper/src/api/execution/buffered_compute.rs b/modules/c-wrapper/src/api/execution/buffered_compute.rs new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/src/api/execution/mod.rs b/modules/c-wrapper/src/api/execution/mod.rs index 1bff97b..590975c 100644 --- a/modules/c-wrapper/src/api/execution/mod.rs +++ b/modules/c-wrapper/src/api/execution/mod.rs @@ -1,2 +1,3 @@ //! The C API for executing ML models. pub mod raw_compute; +pub mod buffered_compute; diff --git a/modules/c-wrapper/tests/test_execution.py b/modules/c-wrapper/tests/api/execution/test_raw_compute.py similarity index 66% rename from modules/c-wrapper/tests/test_execution.py rename to modules/c-wrapper/tests/api/execution/test_raw_compute.py index 7bb4af4..cbe3e64 100644 --- a/modules/c-wrapper/tests/test_execution.py +++ b/modules/c-wrapper/tests/api/execution/test_raw_compute.py @@ -1,30 +1,11 @@ import ctypes -import platform -from pathlib import Path from unittest import TestCase, main + from test_utils.c_lib_loader import load_library +from test_utils.return_structs import FileInfo, Vecf32Return from test_utils.routes import TEST_SURML_PATH -class FileInfo(ctypes.Structure): - _fields_ = [ - ("file_id", ctypes.c_char_p), - ("name", ctypes.c_char_p), - ("description", ctypes.c_char_p), - ("version", ctypes.c_char_p), - ("error_message", ctypes.c_char_p), # Optional error message - ] - -class Vecf32Return(ctypes.Structure): - _fields_ = [ - ("data", ctypes.POINTER(ctypes.c_float)), # Pointer to f32 array - ("length", ctypes.c_size_t), # Length of the array - ("capacity", ctypes.c_size_t), # Capacity of the array - ("is_error", ctypes.c_int), # Indicates if it's an error - ("error_message", ctypes.c_char_p), # Optional error message - ] - - class TestExecution(TestCase): def setUp(self) -> None: @@ -60,7 +41,7 @@ def test_raw_compute(self): # Extract and verify the computation result outcome = [result.data[i] for i in range(result.length)] - print(f"Computation Result: {outcome}") + self.assertEqual(1.8246129751205444, outcome[0]) # Free allocated memory self.lib.free_vecf32_return(result) diff --git a/modules/c-wrapper/tests/api/storage/test_meta.py b/modules/c-wrapper/tests/api/storage/test_meta.py index 4873fa0..1d02f77 100644 --- a/modules/c-wrapper/tests/api/storage/test_meta.py +++ b/modules/c-wrapper/tests/api/storage/test_meta.py @@ -3,10 +3,12 @@ """ import ctypes from unittest import TestCase, main +from typing import Optional +import os from test_utils.c_lib_loader import load_library -from test_utils.return_structs import EmptyReturn, FileInfo -from test_utils.routes import TEST_SURML_PATH +from test_utils.return_structs import EmptyReturn, FileInfo, StringReturn +from test_utils.routes import TEST_SURML_PATH, TEST_ONNX_FILE_PATH, ASSETS_PATH class TestMeta(TestCase): @@ -33,12 +35,24 @@ def setUp(self) -> None: self.lib.load_model.restype = FileInfo self.lib.load_model.argtypes = [ctypes.c_char_p] self.lib.free_file_info.argtypes = [FileInfo] + # define the load raw model signature + self.lib.load_cached_raw_model.restype = StringReturn + self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p] + # define the dave model signature + self.lib.save_model.restype = EmptyReturn + self.lib.save_model.argtypes = [ctypes.c_char_p, ctypes.c_char_p] + # load the model for tests self.model: FileInfo = self.lib.load_model(str(TEST_SURML_PATH).encode('utf-8')) self.file_id = self.model.file_id.decode('utf-8') + self.temp_test_id: Optional[str] = None def tearDown(self) -> None: self.lib.free_file_info(self.model) + # remove the temp surml file created in assets if present + if self.test_temp_surml_file_path is not None: + os.remove(self.test_temp_surml_file_path) + def test_null_protection(self): placeholder = "placeholder".encode('utf-8') file_id = self.file_id.encode('utf-8') @@ -81,6 +95,48 @@ def test_model_not_found(self): self.assertEqual(1, outcome.is_error) self.assertEqual("Model not found", outcome.error_message.decode('utf-8')) + def test_add_metadata_and_save(self): + file_id: StringReturn = self.lib.load_cached_raw_model(str(TEST_SURML_PATH).encode('utf-8')) + self.assertEqual(0, file_id.is_error) + + decoded_file_id = file_id.string.decode('utf-8') + self.temp_test_id = decoded_file_id + + self.assertEqual( + 0, + self.lib.add_name(file_id.string, "test name".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.add_description(file_id.string, "test description".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.add_version(file_id.string, "0.0.1".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.add_author(file_id.string, "test author".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.save_model(self.test_temp_surml_file_path.encode("utf-8"), file_id.string).is_error + ) + + outcome: FileInfo = self.lib.load_model(self.test_temp_surml_file_path.encode('utf-8')) + self.assertEqual(0, outcome.is_error) + self.assertEqual("test name", outcome.name.decode('utf-8')) + self.assertEqual("test description", outcome.description.decode('utf-8')) + self.assertEqual("0.0.1", outcome.version.decode('utf-8')) + + + @property + def test_temp_surml_file_path(self) -> Optional[str]: + if self.temp_test_id is None: + return None + return str(ASSETS_PATH.joinpath(f"{self.temp_test_id}.surml")) + + if __name__ == '__main__': main() diff --git a/modules/c-wrapper/tests/test_utils/return_structs.py b/modules/c-wrapper/tests/test_utils/return_structs.py index bc0479f..3baaa35 100644 --- a/modules/c-wrapper/tests/test_utils/return_structs.py +++ b/modules/c-wrapper/tests/test_utils/return_structs.py @@ -1,7 +1,7 @@ """ Defines all the C structs that are used in the tests. """ -from ctypes import Structure, c_char_p, c_int +from ctypes import Structure, c_char_p, c_int, c_size_t, POINTER, c_float class StringReturn(Structure): @@ -52,3 +52,13 @@ class FileInfo(Structure): ("error_message", c_char_p), # Corresponds to *mut c_char ("is_error", c_int) # Corresponds to c_int ] + + +class Vecf32Return(Structure): + _fields_ = [ + ("data", POINTER(c_float)), # Pointer to f32 array + ("length", c_size_t), # Length of the array + ("capacity", c_size_t), # Capacity of the array + ("is_error", c_int), # Indicates if it's an error + ("error_message", c_char_p), # Optional error message + ]