Skip to content

Commit

Permalink
additional ts tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fulpm committed Jan 31, 2025
1 parent 18422fe commit ad74d5f
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions unit_tests/ut_h_ablate_ts.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
trained_cases (null)
ablated_indices []
warnings []
train_statuses []
)

;create a fresh trainee
Expand Down Expand Up @@ -61,10 +62,16 @@
(call_entity "howso" "train" (assoc
session session
features features
cases (unzip cases (range (current_index 2) (+ (current_index 2) 19)))
cases
;batch up to 20 (filter out nulls if batch exceeds remaining cases)
(filter (unzip
cases
(range (current_index 2) (+ (current_index 2) 19))
))
))
)
(accum (assoc
train_statuses (get response 0)
warnings (or (get response [1 "warnings"]) [])
ablated_indices
(map
Expand All @@ -77,6 +84,14 @@
0 (size cases) 20
)

;verify all trains were successful
(print "All train batches returned success status: ")
(call assert_same (assoc
obs (size (filter (lambda (= 0 (current_value))) train_statuses))
exp 0
))
(call exit_if_failures (assoc msg "Trains completed succssfully"))

(if (size expected_warnings)
;match that at least one of the expected warnings is raised, and no others
(seq
Expand Down Expand Up @@ -145,6 +160,17 @@
exp (size cases)
))

(print "Ablated indices do not contain nulls: ")
(call assert_false (assoc
obs (contains_value ablated_indices (null))
))

(print "Ablated indices are unique: ")
(call assert_same (assoc
obs (size (values abalted_indices (true)))
exp (size abalted_indices)
))

(print "Session training indices match original indices: ")
(call assert_same (assoc
obs original_indices
Expand Down Expand Up @@ -174,8 +200,13 @@
exp 0
))

(print "Series indices do not contain nulls: ")
(call assert_false (assoc
obs (contains_value series_indices (null))
))

;sort trained cases by the date column
(assign (assoc
;sort trained cases by the date column
trained_cases
(sort
(lambda (let
Expand All @@ -193,11 +224,6 @@
)
))

(print "Series indices do not contain nulls: ")
(call assert_false (assoc
obs (contains_value series_indices (null))
))

;per series checks
(map
(lambda (let
Expand Down Expand Up @@ -278,7 +304,7 @@
(map
;map in the expected session training index
(lambda (append (current_value) (current_index)) )
;mixed indices such that each series is still sequetnail but most train batches include both series
;mixed indices such that each series is still sequential but most train batches include both series
(unzip
(tail dataset)
[
Expand Down

0 comments on commit ad74d5f

Please sign in to comment.