Skip to content

Commit

Permalink
raw-compute testing is now working
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Dec 26, 2024
1 parent ac1c133 commit 9922332
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 25 deletions.
Empty file.
1 change: 1 addition & 0 deletions modules/c-wrapper/src/api/execution/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
//! The C API for executing ML models.
pub mod raw_compute;
pub mod buffered_compute;
Original file line number Diff line number Diff line change
@@ -1,30 +1,11 @@
import ctypes
import platform
from pathlib import Path
from unittest import TestCase, main

from test_utils.c_lib_loader import load_library
from test_utils.return_structs import FileInfo, Vecf32Return
from test_utils.routes import TEST_SURML_PATH


class FileInfo(ctypes.Structure):
_fields_ = [
("file_id", ctypes.c_char_p),
("name", ctypes.c_char_p),
("description", ctypes.c_char_p),
("version", ctypes.c_char_p),
("error_message", ctypes.c_char_p), # Optional error message
]

class Vecf32Return(ctypes.Structure):
_fields_ = [
("data", ctypes.POINTER(ctypes.c_float)), # Pointer to f32 array
("length", ctypes.c_size_t), # Length of the array
("capacity", ctypes.c_size_t), # Capacity of the array
("is_error", ctypes.c_int), # Indicates if it's an error
("error_message", ctypes.c_char_p), # Optional error message
]


class TestExecution(TestCase):

def setUp(self) -> None:
Expand Down Expand Up @@ -60,7 +41,7 @@ def test_raw_compute(self):

# Extract and verify the computation result
outcome = [result.data[i] for i in range(result.length)]
print(f"Computation Result: {outcome}")
self.assertEqual(1.8246129751205444, outcome[0])

# Free allocated memory
self.lib.free_vecf32_return(result)
Expand Down
60 changes: 58 additions & 2 deletions modules/c-wrapper/tests/api/storage/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""
import ctypes
from unittest import TestCase, main
from typing import Optional
import os

from test_utils.c_lib_loader import load_library
from test_utils.return_structs import EmptyReturn, FileInfo
from test_utils.routes import TEST_SURML_PATH
from test_utils.return_structs import EmptyReturn, FileInfo, StringReturn
from test_utils.routes import TEST_SURML_PATH, TEST_ONNX_FILE_PATH, ASSETS_PATH


class TestMeta(TestCase):
Expand All @@ -33,12 +35,24 @@ def setUp(self) -> None:
self.lib.load_model.restype = FileInfo
self.lib.load_model.argtypes = [ctypes.c_char_p]
self.lib.free_file_info.argtypes = [FileInfo]
# define the load raw model signature
self.lib.load_cached_raw_model.restype = StringReturn
self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p]
# define the dave model signature
self.lib.save_model.restype = EmptyReturn
self.lib.save_model.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
# load the model for tests
self.model: FileInfo = self.lib.load_model(str(TEST_SURML_PATH).encode('utf-8'))
self.file_id = self.model.file_id.decode('utf-8')
self.temp_test_id: Optional[str] = None

def tearDown(self) -> None:
self.lib.free_file_info(self.model)

# remove the temp surml file created in assets if present
if self.test_temp_surml_file_path is not None:
os.remove(self.test_temp_surml_file_path)

def test_null_protection(self):
placeholder = "placeholder".encode('utf-8')
file_id = self.file_id.encode('utf-8')
Expand Down Expand Up @@ -81,6 +95,48 @@ def test_model_not_found(self):
self.assertEqual(1, outcome.is_error)
self.assertEqual("Model not found", outcome.error_message.decode('utf-8'))

def test_add_metadata_and_save(self):
file_id: StringReturn = self.lib.load_cached_raw_model(str(TEST_SURML_PATH).encode('utf-8'))
self.assertEqual(0, file_id.is_error)

decoded_file_id = file_id.string.decode('utf-8')
self.temp_test_id = decoded_file_id

self.assertEqual(
0,
self.lib.add_name(file_id.string, "test name".encode('utf-8')).is_error
)
self.assertEqual(
0,
self.lib.add_description(file_id.string, "test description".encode('utf-8')).is_error
)
self.assertEqual(
0,
self.lib.add_version(file_id.string, "0.0.1".encode('utf-8')).is_error
)
self.assertEqual(
0,
self.lib.add_author(file_id.string, "test author".encode('utf-8')).is_error
)
self.assertEqual(
0,
self.lib.save_model(self.test_temp_surml_file_path.encode("utf-8"), file_id.string).is_error
)

outcome: FileInfo = self.lib.load_model(self.test_temp_surml_file_path.encode('utf-8'))
self.assertEqual(0, outcome.is_error)
self.assertEqual("test name", outcome.name.decode('utf-8'))
self.assertEqual("test description", outcome.description.decode('utf-8'))
self.assertEqual("0.0.1", outcome.version.decode('utf-8'))


@property
def test_temp_surml_file_path(self) -> Optional[str]:
if self.temp_test_id is None:
return None
return str(ASSETS_PATH.joinpath(f"{self.temp_test_id}.surml"))



if __name__ == '__main__':
main()
12 changes: 11 additions & 1 deletion modules/c-wrapper/tests/test_utils/return_structs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Defines all the C structs that are used in the tests.
"""
from ctypes import Structure, c_char_p, c_int
from ctypes import Structure, c_char_p, c_int, c_size_t, POINTER, c_float


class StringReturn(Structure):
Expand Down Expand Up @@ -52,3 +52,13 @@ class FileInfo(Structure):
("error_message", c_char_p), # Corresponds to *mut c_char
("is_error", c_int) # Corresponds to c_int
]


class Vecf32Return(Structure):
_fields_ = [
("data", POINTER(c_float)), # Pointer to f32 array
("length", c_size_t), # Length of the array
("capacity", c_size_t), # Capacity of the array
("is_error", c_int), # Indicates if it's an error
("error_message", c_char_p), # Optional error message
]

0 comments on commit 9922332

Please sign in to comment.