Skip to content

Commit

Permalink
attached cat map to prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Sep 15, 2024
1 parent f563102 commit cfd5e60
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
unreleased

- fixed #19

0.8.2
- fixed metric bug

Expand Down
2 changes: 1 addition & 1 deletion deps.edn
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{:deps {org.clojure/clojure {:mvn/version "1.11.2"}
{:deps {org.clojure/clojure {:mvn/version "1.12.0"}
scicloj/metamorph {:mvn/version "0.2.4"}
pppmap/pppmap {:mvn/version "1.0.0"}
scicloj/tablecloth {:mvn/version "7.029.2"}
Expand Down
24 changes: 20 additions & 4 deletions src/scicloj/metamorph/ml.clj
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,18 @@
;; (vec (distinct simple-predicted-values))
;; (-> target-cat-maps-from-predict vals first :lookup-table)))

(defn- assoc-categorical-maps
  "Attaches categorical maps to the target columns of a prediction dataset.

  Arguments:
  - `pred-ds`                a dataset, as returned by a model's predict-fn
  - `target-categorical-map` map of column name -> categorical map, or nil
  - `target-columns`         seq of column names to update

  Returns `pred-ds` with the `:categorical-map` metadata key set on every
  column in `target-columns` that has an entry in `target-categorical-map`.
  Columns without an entry are left untouched, so a `:categorical-map` that
  is already present on the column is never overwritten with nil.
  When `target-categorical-map` is nil, `pred-ds` is returned unchanged."
  [pred-ds target-categorical-map target-columns]
  (if target-categorical-map
    (reduce (fn [ds col]
              ;; Only touch the column when a map was recorded for it;
              ;; an empty map `{}` is truthy and would otherwise lead to
              ;; `:categorical-map nil` being written on every target column.
              (if-some [cat-map (get target-categorical-map col)]
                (ds/assoc-metadata ds [col] :categorical-map cat-map)
                ds))
            pred-ds
            target-columns)
    pred-ds))


(defn predict
"Predict returns a dataset with only the predictions in it.
Expand All @@ -682,14 +694,18 @@
(let [{:keys [predict-fn] :as model-def} (options->model-def (:options model))
feature-ds (ds/select-columns dataset (:feature-columns model))
thawed-model (thaw-model model model-def)
pred-ds (predict-fn feature-ds
thawed-model
model)]
pred-ds
(->
(predict-fn feature-ds
thawed-model
model)
(assoc-categorical-maps
(:target-categorical-maps model)
(:target-columns model)))]

(warn-inconsitent-maps model pred-ds)
pred-ds))


(defn loglik [model y yhat]

(let [loglik-fn
Expand Down
2 changes: 1 addition & 1 deletion src/scicloj/metamorph/ml/classification.clj
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,6 @@
(fn [] (rand-nth (:distinct-labels model-data))))))]

(ds/new-dataset [(ds/new-column target-column-name dummy-labels {:column-type :prediction
:categorical-map (get target-categorical-maps target-column-name)})])))
})])))

{:glance-fn (fn [_] (ds/->dataset {:npar 0}))})
7 changes: 6 additions & 1 deletion test/scicloj/metamorph/classification_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,16 @@

prediction (ml/predict ds model)]

(is (=
["setosa","setosa"]
(-> prediction
(ds-cat/reverse-map-categorical-xforms)
(tc/head 2)
:species)))

(is (= (:species prediction) (repeat 150 0)))))



(deftest dummy-classification-majority []
(let [ds (toydata/breast-cancer-ds)
model (ml/train ds {:model-type :metamorph.ml/dummy-classifier
Expand Down
58 changes: 49 additions & 9 deletions test/scicloj/metamorph/ml_test.clj
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
(ns scicloj.metamorph.ml-test
(:require
[scicloj.ml.smile.classification]
[clojure.test :as t :refer [deftest is]]
[confuse.multi-class-metrics :as mcm]
[malli.core :as m]
[scicloj.metamorph.core :as morph]
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.evaluation-handler
:as eval
:refer [qualify-pipelines]]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.metamorph.ml.toydata :as toydata]
[scicloj.metamorph.ml.metrics]
[tablecloth.api :as tc]
[taoensso.nippy :as nippy]
[tech.v3.dataset :as ds]
[tech.v3.dataset.column-filters :as cf]
[tech.v3.dataset.categorical :as ds-cat]
Expand Down Expand Up @@ -51,7 +49,7 @@

(let [
predic-col (ds/new-column :species (repeat (tc/row-count feature-ds) 1)
{:categorical-map (get target-categorical-maps (first target-columns))
{
:column-type :prediction})
predict-ds (ds/new-dataset [predic-col])]

Expand Down Expand Up @@ -473,14 +471,14 @@
(->
(ds/->dataset {:x [0 1 ] :target ["x" "y"]})
(ds-mod/set-inference-target :target)
(ml/train {:model-type :test-model-float-predictions}))]


(ml/train {:model-type :test-model-float-predictions}))
prediction (ml/predict (ds/->dataset {:x [0]}) model)
]
(is (= [1.0]
(-> (ml/predict (ds/->dataset {:x [0]}) model) :species)))))
(-> prediction :species)))))


(deftest test-predict-striong
(deftest test-predict-string
(let [model
(->
(ds/->dataset {:x [0 1 ] :target ["x" "y"]})
Expand All @@ -492,4 +490,46 @@
(-> (ml/predict (ds/->dataset {:x [0]}) model) :species)))))


(deftest test-cat-reverse-float
  ;; Trains a logistic regression on a string target that was converted with
  ;; `ds/categorical->number` (no explicit result datatype given), then checks
  ;; that the prediction can be mapped back to the original labels while the
  ;; raw prediction column holds the numeric code.
  (let [train-ds   (-> {:x [0 1] :target ["x" "y"]}
                       (ds/->dataset)
                       (ds/categorical->number [:target])
                       (ds-mod/set-inference-target :target))
        model      (ml/train train-ds
                             {:model-type :smile.classification/logistic-regression})
        prediction (ml/predict (ds/->dataset {:x [0]}) model)]

    ;; reverse-mapped prediction yields the original category label
    (is (= ["x"]
           (:target (ds-cat/reverse-map-categorical-xforms prediction))))

    ;; raw prediction is the numeric code as a float
    (is (= [0.0]
           (:target prediction)))))


(deftest test-cat-reverse-int
  ;; Same scenario as the float variant, but `ds/categorical->number` is
  ;; called with `[]` and `:int16` as extra arguments (presumably table-args
  ;; and result datatype -- confirm against the tech.v3.dataset docs).
  (let [train-ds   (-> {:x [0 1] :target ["x" "y"]}
                       (ds/->dataset)
                       (ds/categorical->number [:target] [] :int16)
                       (ds-mod/set-inference-target :target))
        model      (ml/train train-ds
                             {:model-type :smile.classification/logistic-regression})
        prediction (ml/predict (ds/->dataset {:x [0]}) model)]

    ;; reverse-mapped prediction yields the original category label
    (is (= ["x"]
           (:target (ds-cat/reverse-map-categorical-xforms prediction))))

    ;; TODO inconsistent: the training target holds integer codes while the
    ;; predicted target holds floats.
    ;; https://github.com/scicloj/scicloj.ml.smile/issues/16
    (is (= [0 1]
           (seq (:target train-ds))))
    (is (= [0.0]
           (seq (:target prediction))))))


0 comments on commit cfd5e60

Please sign in to comment.