diff --git a/ml-module/model-example/main.py b/ml-module/model-example/main.py index 11f3e7e..27b3d39 100644 --- a/ml-module/model-example/main.py +++ b/ml-module/model-example/main.py @@ -3,11 +3,11 @@ import ml -SAMPLES_LEN = 250 -TOTAL_SAMPLES = SAMPLES_LEN * 3 +TOTAL_SAMPLES = ml.get_input_length() acc_x_y_z = [0] * TOTAL_SAMPLES print("Model labels: {}".format(ml.get_labels())) +print("Input size: {}".format(TOTAL_SAMPLES)) i = 0 while True: diff --git a/ml-module/model-example/model_example.c b/ml-module/model-example/model_example.c index dad5d0e..7387e1f 100644 --- a/ml-module/model-example/model_example.c +++ b/ml-module/model-example/model_example.c @@ -10,9 +10,9 @@ .magic0 = MODEL_LABELS_MAGIC0, .header_size = 0x31, // 49 .model_offset = 0x34, // 52 - .number_of_labels = 0x04, .reserved = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, - // 33 bytes + .number_of_labels = 0x04, + // 33 bytes + 3 extra null terminators at the end .labels = { "Jumping\0" "Running\0" @@ -23,7 +23,7 @@ const unsigned int model_example[ml4f_full_model_size] = { // Manually converted ml4f_model_example_header - 0x4D444C42, 0x00340031, 0x00000004, 0x00000000, + 0x4D444C42, 0x00340031, 0x00000000, 0x04000000, 0x706D754A, 0x00676E69, 0x6E6E7552, 0x00676E69, 0x6E617453, 0x676E6964, 0x6C615700, 0x676E696B, 0x00000000, diff --git a/ml-module/src/mlmodel.c b/ml-module/src/mlmodel.c index 5051f41..9815d30 100644 --- a/ml-module/src/mlmodel.c +++ b/ml-module/src/mlmodel.c @@ -105,6 +105,14 @@ bool is_model_present(void) { return model_header != NULL; } +size_t get_input_length(void) { + ml4f_header_t *ml4f_model = get_ml4f_model(); + if (ml4f_model == NULL) { + return 0; + } + return ml4f_shape_elements(ml4f_input_shape(ml4f_model)); +} + size_t get_model_label_num(void) { ml_model_header_t *model_header = get_model_header(); return (model_header != NULL) ? model_header->number_of_labels : 0; diff --git a/ml-module/src/mlmodel.h b/ml-module/src/mlmodel.h index c6675e6..0024b8f 100644 --- a/ml-module/src/mlmodel.h +++ b/ml-module/src/mlmodel.h @@ -12,8 +12,8 @@ typedef __PACKED_STRUCT ml_model_header_t { uint32_t magic0; uint16_t header_size; // Size of this header + all label strings uint16_t model_offset; // header_size + padding to 4 bytes - uint8_t number_of_labels; // Only 255 labels supported uint8_t reserved[7]; + uint8_t number_of_labels; // Only 255 labels supported char labels[]; // Mutiple null-terminated strings, as many as number_of_labels } ml_model_header_t; @@ -29,9 +29,11 @@ typedef struct ml_prediction_s { float *predictions; } ml_prediction_t; + bool get_use_built_in_model(void); void set_use_built_in_model(bool use); bool is_model_present(void); +size_t get_input_length(void); size_t get_model_label_num(void); ml_labels_t* get_model_labels(void); size_t get_model_input_num(void); diff --git a/ml-module/src/mlmodule.c b/ml-module/src/mlmodule.c index bc65b7e..17d8809 100644 --- a/ml-module/src/mlmodule.c +++ b/ml-module/src/mlmodule.c @@ -10,6 +10,10 @@ mp_obj_t internal_model_func(size_t n_args, const mp_obj_t *args) { } static MP_DEFINE_CONST_FUN_OBJ_VAR(internal_model_func_obj, 0, internal_model_func); +mp_obj_t get_input_length_func(void) { + return mp_obj_new_int(get_input_length()); +} +static MP_DEFINE_CONST_FUN_OBJ_0(get_input_length_func_obj, get_input_length_func); mp_obj_t get_labels_func(void) { ml_labels_t* labels = get_model_labels(); @@ -106,6 +110,7 @@ static const mp_rom_map_elem_t ml_module_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ml) }, { MP_ROM_QSTR(MP_QSTR___init__), MP_ROM_PTR(&ml___init___obj) }, { MP_ROM_QSTR(MP_QSTR_internal_model), MP_ROM_PTR(&internal_model_func_obj) }, + { MP_ROM_QSTR(MP_QSTR_get_input_length), MP_ROM_PTR(&get_input_length_func_obj) }, { MP_ROM_QSTR(MP_QSTR_get_labels), MP_ROM_PTR(&get_labels_func_obj) }, { MP_ROM_QSTR(MP_QSTR_predict), MP_ROM_PTR(&predict_func_obj) }, };