Skip to content

Commit

Permalink
added handling of tidy-text sparse columns
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Oct 14, 2024
1 parent aa01279 commit 898f92c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 29 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# ConstantChangeLog

## unreleased
- fixed issue #1


## 6.1.0

Upgrade to xgboost4j_2.12 2.1.1
Expand Down
38 changes: 28 additions & 10 deletions src/scicloj/ml/xgboost.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
[tech.v3.datatype :as dtype]
[tech.v3.datatype.errors :as errors]
[tech.v3.tensor :as dtt]
[scicloj.ml.xgboost.csr :as csr])
[scicloj.ml.xgboost.csr :as csr]
[scicloj.metamorph.ml.text :as text])
(:import [java.io ByteArrayInputStream ByteArrayOutputStream]
[java.util LinkedHashMap Map]
[ml.dmlc.xgboost4j LabeledPoint]
Expand Down Expand Up @@ -198,8 +199,18 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
nil))


(defn tidy-text-bow-ds->dmatrix [bow]
(let [zero-baseddocs-map
(defn tidy-text-bow-ds->dmatrix [feature-ds target-ds]
(def feature-ds feature-ds)
(def target-ds target-ds)

;(-> feature-ds :word .data .data)
;(:label target-ds)

(let [ds (if (some? target-ds)
(assoc feature-ds :label (:label target-ds))
feature-ds)
bow (text/add-word-idx ds)
zero-baseddocs-map
(zipmap
(-> bow :document distinct)
(range))
Expand Down Expand Up @@ -230,7 +241,9 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(float-array (:values csr))
DMatrix$SparseType/CSR
n-col)]
(.setLabel m (float-array labels))
(def labels labels)
(when target-ds
(.setLabel m (float-array labels)))
m))


Expand Down Expand Up @@ -287,11 +300,18 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that

(defn ->dmatrix [feature-ds target-ds sparse-column n-sparse-columns]
(if sparse-column
(sparse-feature->dmatrix feature-ds target-ds sparse-column n-sparse-columns)
(if (= (-> feature-ds (get sparse-column) first class)
SparseArray)
(sparse-feature->dmatrix feature-ds target-ds sparse-column n-sparse-columns)
(tidy-text-bow-ds->dmatrix feature-ds target-ds)

)

(dataset->dmatrix feature-ds target-ds)))




(defn- thaw-model
[model-data]
(-> (if (map? model-data)
Expand Down Expand Up @@ -393,12 +413,12 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(ds-mod/inference-target-label-map label-ds))]
(train-from-dmatrix train-dmat feature-cnames target-cnames options label-map objective)))


(defn- predict
[feature-ds thawed-model {:keys [target-columns target-categorical-maps options]}]
(let [sparse-column-or-nil (:sparse-column options)
dmatrix (->dmatrix feature-ds nil sparse-column-or-nil (:n-sparse-columns options))
prediction (.predict ^Booster thawed-model dmatrix)

predict-tensor (->> prediction
(dtt/->tensor))
target-cname (first target-columns)]
Expand All @@ -407,17 +427,15 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(if (multiclass-objective? (options->objective options))
(->
(model/finalize-classification predict-tensor
(ds/row-count feature-ds)
target-cname
target-categorical-maps)

(tech.v3.dataset.modelling/probability-distributions->label-column
(first target-columns))
(tech.v3.dataset.modelling/probability-distributions->label-column target-cname)
(ds/update-column (first target-columns)
#(vary-meta % assoc :column-type :prediction)))
(model/finalize-regression predict-tensor target-cname))))




(defn- explain
Expand Down
15 changes: 11 additions & 4 deletions test/scicloj/ml/text_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
[scicloj.ml.xgboost.csr :as csr]
[tablecloth.api :as tc]
[tablecloth.column.api :as tcc]
[scicloj.metamorph.ml :as ml])
[scicloj.metamorph.ml :as ml]
[tech.v3.dataset.column-filters :as cf]
[tech.v3.dataset :as ds])
(:import [java.util.zip GZIPInputStream]
[ml.dmlc.xgboost4j.java XGBoost]
[ml.dmlc.xgboost4j.java DMatrix DMatrix$SparseType]))
Expand All @@ -26,7 +28,7 @@
[(first splitted)
(dec (Integer/parseInt (second splitted)))]))
#(str/split % #" ")
:max-lines 10000
:max-lines 1000
:skip-lines 1)
(tc/rename-columns {:meta :label})
(tc/drop-rows #(= "" (:word %)))
Expand All @@ -48,8 +50,10 @@
text/->term-frequency
text/add-word-idx)

m-train (xgboost/tidy-text-bow-ds->dmatrix bow-train)
m-test (xgboost/tidy-text-bow-ds->dmatrix bow-test)
m-train (xgboost/tidy-text-bow-ds->dmatrix (cf/feature bow-train)
(tc/select-columns bow-train [:label]) )
m-test (xgboost/tidy-text-bow-ds->dmatrix (cf/feature bow-test)
(tc/select-columns bow-test [:label]))

model
(xgboost/train-from-dmatrix
Expand Down Expand Up @@ -84,6 +88,9 @@
(float-array predition-test)
(.getLabel m-test))]

(println :train-accuracy train-accuracy)
(println :test-accuracy test-accuracy)

(is (< 0.95 train-accuracy))
(is (< 0.54 test-accuracy))))

Expand Down
29 changes: 14 additions & 15 deletions test/scicloj/ml/xgboost_test.clj
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
(ns scicloj.ml.xgboost-test
(:require [clojure.test :refer [deftest is]]
[fastmath.protocols :as protocols]

[fastmath.vector :as vec]
(:require [clojure.data.csv :as csv]
[clojure.java.io :as io]
[clojure.string :as str]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.gridsearch :as ml-gs]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.metamorph.ml.text :as text]
[scicloj.metamorph.ml.verify :as verify]
[scicloj.ml.smile.discrete-nb :as nb]
[scicloj.ml.smile.nlp :as nlp]
[scicloj.ml.xgboost]
[tablecloth.api :as tc]
[tablecloth.column.api :as tcc]
[tech.v3.dataset :as ds]
[tech.v3.dataset.categorical :as ds-cat]
[tech.v3.dataset.column-filters :as cf]
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.datatype :as dtype]
[tech.v3.datatype.functional :as dfn]
[tech.v3.datatype :as dt]))
[tech.v3.datatype :as dtype]
[tech.v3.datatype.functional :as dfn])
(:import [java.util.zip GZIPInputStream]))


(deftest basic
Expand Down Expand Up @@ -77,7 +77,7 @@
:sparse-column :bow-sparse
:n-sparse-columns 100})


explanation (ml/explain model)
test-ds (ds/head reviews 100)
prediction (ml/predict test-ds model)
Expand All @@ -88,13 +88,12 @@
:Score)
(-> test-ds
(ds-cat/reverse-map-categorical-xforms)
:Score))
]
(is ( > train-acc 0.97))))
:Score))]
(is (> train-acc 0.97))))


(deftest iris
(let [ src-ds (ds/->dataset "test/data/iris.csv")
(deftest iris
(let [src-ds (ds/->dataset "test/data/iris.csv")
ds (-> src-ds
(ds/categorical->number cf/categorical)
(ds-mod/set-inference-target "species"))
Expand All @@ -104,7 +103,7 @@
test-ds (:test-ds split-data)
model (ml/train train-ds {:validate-parameters "true"
:seed 123
:verbosity 1
:verbosity 0
:model-type :xgboost/classification})
predictions (ml/predict test-ds model)
loss
Expand Down

0 comments on commit 898f92c

Please sign in to comment.