@@ -660,7 +660,88 @@ <h2 id="the-math">The math</h2>
660
660
\]</ div >
661
661
< p > When you have the final values from your derivative calculation, you can use it in the gradient descent equation and update the weights and bias.</ p >
662
662
< h2 id ="the-code "> The code</ h2 >
663
- < p > Coming soon</ p >
663
+ < p > The data used here is the < a href ="https://www.kaggle.com/datasets/uciml/breast-cancer-wisconsin-data "> Breast Cancer Wisconsin (Diagnostic) Data Set</ a > which has been modified to look like < a href ="https://gitlab.com/adwaithrajesh/linear-ml-test/-/blob/main/data/bcancer.csv "> this</ a > , where we
664
+ don't have IDs, and M=0 and B=1</ p >
665
+ < div class ="highlight "> < pre > < span > </ span > < code > < span class ="cp "> #define INCLUDE_MAT_CONVERSIONS</ span >
666
+ < span class ="cp "> #include</ span > < span class ="w "> </ span > < span class ="cpf "> "ds/mat.h"</ span >
667
+ < span class ="cp "> #include</ span > < span class ="w "> </ span > < span class ="cpf "> "ml/logisticregress.h"</ span >
668
+ < span class ="cp "> #include</ span > < span class ="w "> </ span > < span class ="cpf "> "model/metrics.h"</ span >
669
+ < span class ="cp "> #include</ span > < span class ="w "> </ span > < span class ="cpf "> "model/train_test_split.h"</ span >
670
+ < span class ="cp "> #include</ span > < span class ="w "> </ span > < span class ="cpf "> "parsers/csv.h"</ span >
671
+
672
+ < span class ="kt "> int</ span > < span class ="w "> </ span > < span class ="nf "> main</ span > < span class ="p "> (</ span > < span class ="kt "> void</ span > < span class ="p "> )</ span > < span class ="w "> </ span > < span class ="p "> {</ span >
673
+ < span class ="w "> </ span > < span class ="n "> CSV</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> csv_reader</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> csv_init</ span > < span class ="p "> (</ span > < span class ="mi "> 569</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="mi "> 31</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="sc "> ','</ span > < span class ="p "> );</ span >
674
+ < span class ="w "> </ span > < span class ="n "> csv_parse</ span > < span class ="p "> (</ span > < span class ="n "> csv_reader</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="s "> "data/bcancer.csv"</ span > < span class ="p "> );</ span >
675
+
676
+ < span class ="w "> </ span > < span class ="n "> Mat</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> X</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> csv_get_mat_slice</ span > < span class ="p "> (</ span > < span class ="n "> csv_reader</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="p "> (</ span > < span class ="n "> Slice</ span > < span class ="p "> ){</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="mi "> 31</ span > < span class ="p "> });</ span >
677
+ < span class ="w "> </ span > < span class ="n "> Mat</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> Y</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> csv_get_mat_slice</ span > < span class ="p "> (</ span > < span class ="n "> csv_reader</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="p "> (</ span > < span class ="n "> Slice</ span > < span class ="p "> ){</ span > < span class ="mi "> 0</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="mi "> 1</ span > < span class ="p "> });</ span >
678
+ < span class ="w "> </ span > < span class ="n "> Mat</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> X_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> X_test</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> Y_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> Y_test</ span > < span class ="p "> ;</ span >
679
+
680
+ < span class ="w "> </ span > < span class ="n "> train_test_split</ span > < span class ="p "> (</ span > < span class ="n "> X</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> Y</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> &</ span > < span class ="n "> X_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> &</ span > < span class ="n "> X_test</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> &</ span > < span class ="n "> Y_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="o "> &</ span > < span class ="n "> Y_test</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="mf "> 0.3</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="mi "> 101</ span > < span class ="p "> );</ span >
681
+
682
+ < span class ="w "> </ span > < span class ="n "> logregress_set_max_iter</ span > < span class ="p "> (</ span > < span class ="mi "> 2000</ span > < span class ="p "> );</ span >
683
+ < span class ="w "> </ span > < span class ="n "> LogisticRegressionModel</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> model</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> logregress_init</ span > < span class ="p "> ();</ span >
684
+ < span class ="w "> </ span > < span class ="n "> logregress_fit</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> X_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> Y_train</ span > < span class ="p "> );</ span >
685
+
686
+ < span class ="w "> </ span > < span class ="c1 "> // printf("prediction: %lf\n", logregress_predict(model, (double[]){15.22, 30.62, 103.4, 716.9, ... , 0}, 30));</ span >
687
+ < span class ="w "> </ span > < span class ="n "> Array</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> preds</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> logregress_predict_many</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> X_test</ span > < span class ="p "> );</ span >
688
+ < span class ="w "> </ span > < span class ="n "> Array</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="nb "> true</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> mat_get_col_arr</ span > < span class ="p "> (</ span > < span class ="n "> Y_test</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="mi "> 0</ span > < span class ="p "> );</ span >
689
+
690
+ < span class ="w "> </ span > < span class ="n "> logregress_print</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> );</ span >
691
+
692
+ < span class ="w "> </ span > < span class ="n "> printf</ span > < span class ="p "> (</ span > < span class ="s "> "confusion matrix: </ span > < span class ="se "> \n</ span > < span class ="s "> "</ span > < span class ="p "> );</ span >
693
+ < span class ="w "> </ span > < span class ="n "> Mat</ span > < span class ="w "> </ span > < span class ="o "> *</ span > < span class ="n "> conf_mat</ span > < span class ="w "> </ span > < span class ="o "> =</ span > < span class ="w "> </ span > < span class ="n "> model_confusion_matrix</ span > < span class ="p "> (</ span > < span class ="nb "> true</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> preds</ span > < span class ="p "> );</ span >
694
+ < span class ="w "> </ span > < span class ="n "> mat_print</ span > < span class ="p "> (</ span > < span class ="n "> conf_mat</ span > < span class ="p "> );</ span >
695
+
696
+ < span class ="w "> </ span > < span class ="n "> arr_free</ span > < span class ="p "> (</ span > < span class ="nb "> true</ span > < span class ="p "> );</ span >
697
+ < span class ="w "> </ span > < span class ="n "> arr_free</ span > < span class ="p "> (</ span > < span class ="n "> preds</ span > < span class ="p "> );</ span >
698
+ < span class ="w "> </ span > < span class ="n "> logregress_free</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> );</ span >
699
+ < span class ="w "> </ span > < span class ="n "> mat_free_many</ span > < span class ="p "> (</ span > < span class ="mi "> 7</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> X</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> Y</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> X_test</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> X_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> Y_test</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> Y_train</ span > < span class ="p "> ,</ span > < span class ="w "> </ span > < span class ="n "> conf_mat</ span > < span class ="p "> );</ span >
700
+ < span class ="w "> </ span > < span class ="n "> csv_free</ span > < span class ="p "> (</ span > < span class ="n "> csv_reader</ span > < span class ="p "> );</ span >
701
+ < span class ="p "> }</ span >
702
+ </ code > </ pre > </ div >
703
+ < div class ="highlight "> < pre > < span > </ span > < code > < span class ="go "> LogisticRegressionModel(bias: 0.5159147, loss: -12.4263621, weights: 0x5556e8a732c0)</ span >
704
+ < span class ="go "> weights:</ span >
705
+ < span class ="go "> 1546.6922009</ span >
706
+ < span class ="go "> 1139.6829595</ span >
707
+ < span class ="go "> 8552.1648900</ span >
708
+ < span class ="go "> 2522.0044946</ span >
709
+ < span class ="go "> 11.8724211</ span >
710
+ < span class ="go "> -19.3345598</ span >
711
+ < span class ="go "> -44.9646156</ span >
712
+ < span class ="go "> -18.4984994</ span >
713
+ < span class ="go "> 23.8378678</ span >
714
+ < span class ="go "> 10.1676564</ span >
715
+ < span class ="go "> 0.2338315</ span >
716
+ < span class ="go "> 103.3839701</ span >
717
+ < span class ="go "> -139.7864354</ span >
718
+ < span class ="go "> -4498.8563443</ span >
719
+ < span class ="go "> 0.2662770</ span >
720
+ < span class ="go "> -6.5798244</ span >
721
+ < span class ="go "> -8.6158697</ span >
722
+ < span class ="go "> -1.6938180</ span >
723
+ < span class ="go "> 1.6508702</ span >
724
+ < span class ="go "> -0.3857419</ span >
725
+ < span class ="go "> 1650.7843571</ span >
726
+ < span class ="go "> 1445.0283208</ span >
727
+ < span class ="go "> 8312.7672485</ span >
728
+ < span class ="go "> -4024.9280673</ span >
729
+ < span class ="go "> 13.2972726</ span >
730
+ < span class ="go "> -72.4527931</ span >
731
+ < span class ="go "> -111.8298475</ span >
732
+ < span class ="go "> -26.6204266</ span >
733
+ < span class ="go "> 28.0612275</ span >
734
+ < span class ="go "> 5.4099162</ span >
735
+ < span class ="go "> confusion matrix:</ span >
736
+ < span class ="go "> 57.00 10.00</ span >
737
+ < span class ="go "> 2.00 101.00</ span >
738
+ </ code > </ pre > </ div >
739
+ < p > Now, what does the confusion matrix generated by sklearn look like?</ p >
740
+ < div class ="highlight "> < pre > < span > </ span > < code > < span class ="n "> array</ span > < span class ="p "> ([[</ span > < span class ="mi "> 59</ span > < span class ="p "> ,</ span > < span class ="mi "> 7</ span > < span class ="p "> ],</ span >
741
+ < span class ="p "> [</ span > < span class ="mi "> 3</ span > < span class ="p "> ,</ span > < span class ="mi "> 102</ span > < span class ="p "> ]])</ span >
742
+ </ code > </ pre > </ div >
743
+ < p > We are pretty close...
744
+ Check out the Python implementation < a href ="https://gitlab.com/adwaithrajesh/linear-ml-test/-/blob/main/notebooks/log.ipynb "> here</ a > </ p >
664
745
665
746
666
747
0 commit comments