Skip to content

Commit

Permalink
added test/train split test
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Nov 3, 2024
1 parent bf7c4ca commit ee5256b
Showing 1 changed file with 53 additions and 28 deletions.
81 changes: 53 additions & 28 deletions test/scicloj/ml/xgboost_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -228,46 +228,71 @@
first
(tc/drop-missing)
(text/->tfidf)
(tc/rename-columns {:meta :label})
(ds-mod/set-inference-target [:label]))
(tc/rename-columns {:meta :label}))

rnd-documents (shuffle (range 1000))
train-documents (into #{} (take 800 rnd-documents))
test-documents (into #{} (take-last 200 rnd-documents))


n-sparse-columns (inc (apply max (reviews :token-idx)))
model
(ml/train reviews {:model-type :xgboost/classification
:sparse-column :tfidf
:seed 123
:num-class 5
:n-sparse-columns n-sparse-columns})

train-reviews
(-> reviews
(tc/select-rows (fn [row] (contains? train-documents (:document row))))
(ds-mod/set-inference-target :label))

trueth-train
(-> train-reviews
(tc/select-columns [:document :label])
(tc/unique-by [:document :label])
(tc/order-by :document)
:label)

test-reviews reviews

raw-prediction
(->
(ml/predict test-reviews model)
(tc/select-columns [:label :document])
)

prediction
(->
raw-prediction
(tc/order-by :document))
test-reviews
(-> reviews
(tc/select-rows (fn [row] (contains? test-documents (:document row)))))

trueth
trueth-test
(-> test-reviews
(tc/select-columns [:document :label])
(tc/unique-by [:document :label])
(tc/order-by :document)
:label
)
:label)

test-review-clean
(-> test-reviews
(tc/drop-columns [:label]))


n-sparse-columns (inc (apply max (reviews :token-idx)))
model
(ml/train train-reviews {:model-type :xgboost/classification
:sparse-column :tfidf
:seed 123
:num-class 5
:n-sparse-columns n-sparse-columns})


prediction-train
(->
(ml/predict train-reviews model)
(tc/select-columns [:label :document])
(tc/order-by :document))

prediction-test
(->
(ml/predict test-review-clean model)
(tc/select-columns [:label :document])
(tc/order-by :document))

]


(is (< 0.95
(loss/classification-accuracy
(mapv int (:label prediction))
(vec trueth))))))
(mapv int (:label prediction-train))
(vec trueth-train))))

(is (< 0.55
(loss/classification-accuracy
(mapv int (:label prediction-test))
(vec trueth-test))))
))

0 comments on commit ee5256b

Please sign in to comment.