Skip to content

Commit

Permalink
made work with latest metamorph
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Nov 2, 2024
1 parent d8f4260 commit bdfed9b
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ example/results.nippy
.clj-kondo
.calva
.lsp
repeatedAbstrcats.txt
6 changes: 3 additions & 3 deletions deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
ml.dmlc/xgboost4j_2.12 {:mvn/version "2.1.1"}
;ml.dmlc/xgboost4j-spark_2.12 {:mvn/version "2.1.1"} ;; what for ??
org.scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml"
:git/sha "523e1d40d77959973cb3c5687eeff3b945a21818"
}
:git/sha "07003e740c303e56c3961a689a54f7ffdd0c64e7"}
;{:mvn/version "0.9.0"}

com.github.haifengl/smile-core {:mvn/version "2.6.0"}
Expand All @@ -15,7 +14,8 @@
org.slf4j/slf4j-log4j12]}
pppmap/pppmap {:mvn/version "1.0.0"}}

:aliases {
;; :exp adds the experiment sources under exp/ to the classpath and raises
;; the initial JVM heap. :extra-paths must be a vector of paths, not a string.
:aliases {:exp {:extra-paths ["exp"]
                :jvm-opts ["-Xms12g"]}
:codox {:extra-deps {codox/codox {:mvn/version "0.10.7"}
codox-theme-rdash/codox-theme-rdash {:mvn/version "0.1.2"}}
:exec-fn codox.main/generate-docs
Expand Down
117 changes: 117 additions & 0 deletions exp/exp.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
(ns exp
(:require [clojure.data.csv :as csv]
[clojure.java.io :as io]
[clojure.string :as str]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.metamorph.ml.text :as text]
[scicloj.ml.xgboost :as xgboost]
[tablecloth.api :as tc]
[tech.v3.dataset.column-filters :as cf])
(:import [java.util.zip GZIPInputStream]
[ml.dmlc.xgboost4j.java XGBoost]))
;; Number of input lines the experiment works with: it is passed as :max-lines
;; to text/->tidy-text and is also the size of the document-index range that
;; gets shuffled and split into train and test sets below.
(def max-lines 1000) ; fails with 10000

(defn deterministic-shuffle
  "Returns a vector with the elements of `coll` in a pseudo-random order that
  is fully determined by `seed`: the same collection and seed always produce
  the same permutation."
  [^java.util.Collection coll seed]
  ;; Copy into a mutable list, shuffle it in place with a seeded RNG, then
  ;; hand the resulting array over to a persistent vector.
  (let [shuffled (doto (java.util.ArrayList. coll)
                   (java.util.Collections/shuffle (java.util.Random. seed)))]
    (clojure.lang.RT/vector (.toArray shuffled))))

;; Experiment script, evaluated when the namespace is loaded:
;;   1. read up to `max-lines` lines of text and tokenize them into a tidy
;;      dataset, attaching a random label in [0, 5) to every line,
;;   2. split the document indices 80/20 into train and test sets,
;;   3. compute TF-IDF features and convert them to xgboost matrices,
;;   4. train a 5-class "multi:softmax" model and print train/test accuracy.
(let [_ (println :slurp)

      ;; with-open guarantees the reader (and its underlying stream) is closed
      ;; once the dataset has been built, even if tokenizing throws.
      ;; NOTE(review): assumes text/->tidy-text consumes the reader eagerly
      ;; before returning -- confirm against metamorph.ml.
      ds
      (with-open [rdr (io/reader (io/input-stream "repeatedAbstrcats.txt"))]
        (-> (text/->tidy-text rdr
                              line-seq
                              (fn [line]
                                [line
                                 (rand-int 5)])
                              #(str/split % #" ")
                              :max-lines max-lines
                              :skip-lines 1
                              :datatype-token-idx :int32)
            :datasets
            first
            (tc/drop-rows #(= "" (:term %)))
            (tc/drop-missing)))

      ;; The `def`s inside this let expose intermediate results as vars so
      ;; they can be inspected from the REPL after the script has run.
      _ (def ds ds)

      ;(tc/select-rows ds #(= 603 (:document %)))

      ;; 80/20 split of the shuffled document indices. The cut point is
      ;; computed with integer arithmetic so the two parts never overlap,
      ;; whatever the value of max-lines.
      rnd-indexes (-> (range max-lines) (deterministic-shuffle 123))
      n-train (quot (* max-lines 4) 5)
      [rnd-indexes-train rnd-indexes-test] (split-at n-train rnd-indexes)

      ds-train (tc/inner-join (tc/dataset {:document rnd-indexes-train}) ds [:document])
      ds-test (tc/inner-join (tc/dataset {:document rnd-indexes-test}) ds [:document])

      _ (def ds-train ds-train)
      _ (def ds-test ds-test)

      _ (println :->term-frequency)
      bow-train
      (-> ds-train
          text/->tfidf
          (tc/rename-columns {:meta :label}))

      bow-test
      (-> ds-test
          text/->tfidf
          (tc/rename-columns {:meta :label}))

      _ (println :to-matrix)

      _ (def bow-train bow-train)
      m-train (xgboost/tidy-text-bow-ds->dmatrix (cf/feature bow-train)
                                                 (tc/select-columns bow-train [:label])
                                                 :tfidf)
      m-test (xgboost/tidy-text-bow-ds->dmatrix (cf/feature bow-test)
                                                (tc/select-columns bow-test [:label])
                                                :tfidf)

      _ (println :train)
      model
      (xgboost/train-from-dmatrix
       m-train
       ["term"]
       ["label"]
       {:num-class 5
        :validate-parameters "true"
        :seed 123
        :verbosity 0}
       {}
       "multi:softmax")

      ;; Rebuild a native booster from the serialized model bytes so that
      ;; predictions can be made directly on the matrices.
      booster
      (XGBoost/loadModel
       (java.io.ByteArrayInputStream. (:model-data model)))

      _ (println :predict)
      ;; Each prediction row is reduced to its first value, truncated to int.
      prediction-train
      (->>
       (.predict booster m-train)
       (map #(int (first %))))

      prediction-test
      (->>
       (.predict booster m-test)
       (map #(int (first %))))

      train-accuracy
      (loss/classification-accuracy
       (float-array prediction-train)
       (.getLabel m-train))

      test-accuracy
      (loss/classification-accuracy
       (float-array prediction-test)
       (.getLabel m-test))]

  (println :train-accuracy train-accuracy)
  (println :test-accuracy test-accuracy))
8 changes: 6 additions & 2 deletions src/scicloj/ml/xgboost.clj
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that


(defn tidy-text-bow-ds->dmatrix [feature-ds target-ds text-feature-column]
(println :n-features (tc/row-count feature-ds))
(let [ds (if (some? target-ds)
(assoc feature-ds :label (:label target-ds))
feature-ds)
Expand All @@ -214,13 +215,16 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
#(map zero-baseddocs-map (:document %))))
sparse-features
(-> bow-zeroed
(tc/select-columns [:document :term-idx text-feature-column])
(tc/select-columns [:document :token-idx text-feature-column])
(tc/rows))

n-col (inc (apply max (bow-zeroed :term-idx)))
n-col (inc (apply max (bow-zeroed :token-idx)))

csr (csr/->csr sparse-features)

_ (println :max-column-index+1 (inc (apply max (:column-indices csr))))


labels
(->
bow-zeroed
Expand Down
7 changes: 5 additions & 2 deletions src/scicloj/ml/xgboost/csr.clj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
[ tech.v3.datatype :as dt]
))

(defn- add-to-csr [csr row col value]
(set! *warn-on-reflection* true)
(set! *unchecked-math* :warn-on-boxed)

(defn- add-to-csr [csr ^long row ^long col ^double value]
(if (zero? value)
csr
(let [new-values (conj (:values csr) (float value))
Expand Down Expand Up @@ -43,7 +46,7 @@


(defn ->dense [csr rows cols]
(for [i (range rows)]
(for [^long i (range rows)]
(let [row-start (nth (:row-pointers csr) i)
row-end (nth (:row-pointers csr) (inc i))]
(for [j (range cols)]
Expand Down
134 changes: 78 additions & 56 deletions test/scicloj/ml/text_test.clj
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
(ns scicloj.ml.text-test
(:require [clojure.data.csv :as csv]
[clojure.java.io :as io]
[clojure.string :as str]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.metamorph.ml.text :as text]
[scicloj.ml.xgboost :as xgboost]
[scicloj.ml.xgboost.csr :as csr]
[tablecloth.api :as tc]
[tablecloth.column.api :as tcc]
[tech.v3.dataset.column-filters :as cf])
(:import [java.util.zip GZIPInputStream]
[ml.dmlc.xgboost4j.java XGBoost]
[ml.dmlc.xgboost4j.java DMatrix DMatrix$SparseType]))
(:require
[clojure.data.csv :as csv]
[clojure.java.io :as io]
[clojure.set :as set]
[clojure.string :as str]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.metamorph.ml.text :as text]
[scicloj.ml.xgboost :as xgboost]
[scicloj.ml.xgboost.csr :as csr]
[tablecloth.api :as tc]
[tablecloth.column.api :as tcc]
[tech.v3.dataset.column-filters :as cf])
(:import
[java.util.zip GZIPInputStream]
[ml.dmlc.xgboost4j.java XGBoost]
[ml.dmlc.xgboost4j.java DMatrix DMatrix$SparseType]))


(defn deterministic-shuffle
Expand All @@ -23,18 +26,22 @@
(clojure.lang.RT/vector (.toArray al))))

(deftest reviews-accuracy-sparse-matrix-classification
(let [ds
(->
(text/->tidy-text (io/reader (GZIPInputStream. (io/input-stream "test/data/reviews.csv.gz")))
(fn [line]
(let [splitted (first
(csv/read-csv line))]
[(first splitted)
(dec (Integer/parseInt (second splitted)))]))
#(str/split % #" ")
:max-lines 1000
:skip-lines 1)
(tc/rename-columns {:meta :label})
(let [tidy
(text/->tidy-text (io/reader (GZIPInputStream. (io/input-stream "test/data/reviews.csv.gz")))
line-seq
(fn [line]
(let [splitted (first
(csv/read-csv line))]
[(first splitted)
(dec (Integer/parseInt (second splitted)))]))
#(str/split % #" ")
:max-lines 1000
:skip-lines 1)

ds
(-> tidy
:datasets first
;
(tc/drop-rows #(= "" (:term %)))
(tc/drop-missing))

Expand All @@ -45,14 +52,18 @@
ds-train (tc/left-join (tc/dataset {:document rnd-indexes-train}) ds [:document])
ds-test (tc/left-join (tc/dataset {:document rnd-indexes-test}) ds [:document])


bow-train
(-> ds-train
text/->term-frequency

text/->tfidf
(tc/rename-columns {:meta :label})
)

bow-test
(-> ds-test
text/->term-frequency
text/->tfidf
(tc/rename-columns {:meta :label})
)


Expand Down Expand Up @@ -111,31 +122,43 @@

(deftest small-text

(let [ds
(->
(text/->tidy-text (io/reader "test/data/small_text.csv")
(fn [line]
(let [splitted (first
(csv/read-csv line))]
(vector
(first splitted)
(dec (Integer/parseInt (second splitted))))))
#(str/split % #" ")
:max-lines 10000
:skip-lines 1)
(tc/rename-columns {:meta :label}))
(let [tidy-result
(text/->tidy-text (io/reader "test/data/small_text.csv")
line-seq
(fn [line]
(let [splitted (first
(csv/read-csv line))]
(vector
(first splitted)
(dec (Integer/parseInt (second splitted))))))
#(str/split % #" ")
:max-lines 10000
:skip-lines 1)

tidy-ds

(->
tidy-result
:datasets
first)


id->token
(-> tidy-result :token-lookup-table set/map-invert)

bow
(-> ds text/->term-frequency)

(->
tidy-ds
text/->tfidf
(tc/rename-columns {:meta :label}))

sparse-features
(-> bow
(tc/select-columns [:document :term-idx :term-count])
(tc/select-columns [:document :token-idx :token-count])
(tc/rows))

n-rows (inc (apply tcc/max (bow :document)))
n-col (inc (apply max (bow :term-idx)))
n-col (inc (apply max (bow :token-idx)))

csr
(csr/->csr sparse-features)
Expand Down Expand Up @@ -167,8 +190,7 @@
["word"]
["label"]
{:num-class 2
:verbosity 0
}
:verbosity 0}
{}
"multi:softprob")

Expand All @@ -180,25 +202,25 @@
(.predict booster m)]

(is (= ["I", "like", "fish", "and", "you", "the", "fish", "Do", "you", "like", "me", "?"]
(:term ds)))
(map id->token (:token-idx tidy-ds))))




(is (= [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4] (ds :term-index)))
(is (= [1, 2, 3, 4, 5, 6, 3, 7, 5, 2, 8, 9] (tidy-ds :token-idx)))

(is (= [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] (ds :document)))
(is (= [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] (tidy-ds :document)))

(is (= [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] (ds :label)))
(is (= [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] (tidy-ds :meta)))

(is (=
[1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1]
;[1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1]
(:term-count bow)))
(:token-count bow)))


(is (=
[[0 1 1] [0 2 1] [1 2 1] [0 3 2] [0 4 1] [0 5 1] [1 5 1] [0 6 1] [1 7 1] [1 8 1] [1 9 1]]
;[[0 1 1] [0 2 1] [0 3 2] [0 4 1] [0 5 1] [0 6 1] [1 7 1] [1 5 1] [1 2 1] [1 8 1] [1 9 1]]
sparse-features))
[[0 5 1] [0 6 1] [0 1 1] [0 3 2] [0 4 1] [0 2 1] [1 5 1] [1 9 1] [1 7 1] [1 8 1] [1 2 1]]
sparse-features))

(is (= 2 n-rows))
(is (= 10 n-col))
Expand Down
Loading

0 comments on commit bdfed9b

Please sign in to comment.