From ad74d5f2f194d1cead799b3810c1bc58194e51c9 Mon Sep 17 00:00:00 2001 From: Matt Fulp <8397318+fulpm@users.noreply.github.com> Date: Fri, 31 Jan 2025 17:45:30 -0500 Subject: [PATCH] additional ts tests --- unit_tests/ut_h_ablate_ts.amlg | 42 +++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/unit_tests/ut_h_ablate_ts.amlg b/unit_tests/ut_h_ablate_ts.amlg index 5f4d9d35..d88d9d17 100644 --- a/unit_tests/ut_h_ablate_ts.amlg +++ b/unit_tests/ut_h_ablate_ts.amlg @@ -33,6 +33,7 @@ trained_cases (null) ablated_indices [] warnings [] + train_statuses [] ) ;create a fresh trainee @@ -61,10 +62,16 @@ (call_entity "howso" "train" (assoc session session features features - cases (unzip cases (range (current_index 2) (+ (current_index 2) 19))) + cases + ;batch up to 20 (filter out nulls if batch exceeds remaining cases) + (filter (unzip + cases + (range (current_index 2) (+ (current_index 2) 19)) + )) )) ) (accum (assoc + train_statuses (get response 0) warnings (or (get response [1 "warnings"]) []) ablated_indices (map @@ -77,6 +84,14 @@ 0 (size cases) 20 ) + ;verify all trains were successful + (print "All train batches returned success status: ") + (call assert_same (assoc + obs (size (filter (lambda (= 0 (current_value))) train_statuses)) + exp 0 + )) + (call exit_if_failures (assoc msg "Trains completed succssfully")) + (if (size expected_warnings) ;match that at least one of the expected warnings is raised, and no others (seq @@ -145,6 +160,17 @@ exp (size cases) )) + (print "Ablated indices do not contain nulls: ") + (call assert_false (assoc + obs (contains_value ablated_indices (null)) + )) + + (print "Ablated indices are unique: ") + (call assert_same (assoc + obs (size (values abalted_indices (true))) + exp (size abalted_indices) + )) + (print "Session training indices match original indices: ") (call assert_same (assoc obs original_indices @@ -174,8 +200,13 @@ exp 0 )) + (print "Series indices do not contain nulls: ") + (call assert_false (assoc + obs (contains_value series_indices (null)) + )) + + ;sort trained cases by the date column (assign (assoc - ;sort trained cases by the date column trained_cases (sort (lambda (let @@ -193,11 +224,6 @@ ) )) - (print "Series indices do not contain nulls: ") - (call assert_false (assoc - obs (contains_value series_indices (null)) - )) - ;per series checks (map (lambda (let @@ -278,7 +304,7 @@ (map ;map in the expected session training index (lambda (append (current_value) (current_index)) ) - ;mixed indices such that each series is still sequetnail but most train batches include both series + ;mixed indices such that each series is still sequential but most train batches include both series (unzip (tail dataset) [