
Commit b273b95

Deployed b37fdf7 with MkDocs version: 1.5.3
1 parent 5fcd5d4 commit b273b95

3 files changed

+83
-2
lines changed

ml/logregress/index.html

Lines changed: 82 additions & 1 deletion
@@ -660,7 +660,88 @@ <h2 id="the-math">The math</h2>
660660
\]</div>
661661
<p>When you have the final values from your derivative calculation, you can use it in the gradient descent equation and update the weights and bias.</p>
662662
<h2 id="the-code">The code</h2>
663-
<p>Coming soon</p>
663+
<p>The data used here is the <a href="https://www.kaggle.com/datasets/uciml/breast-cancer-wisconsin-data">Breast Cancer Wisconsin (Diagnostic) Data Set</a>, which has been modified to look like <a href="https://gitlab.com/adwaithrajesh/linear-ml-test/-/blob/main/data/bcancer.csv">this</a>: the
664+
id column is removed and the diagnosis is encoded as M=0 and B=1.</p>
665+
<div class="highlight"><pre><span></span><code><span class="cp">#define INCLUDE_MAT_CONVERSIONS</span>
666+
<span class="cp">#include</span><span class="w"> </span><span class="cpf">&quot;ds/mat.h&quot;</span>
667+
<span class="cp">#include</span><span class="w"> </span><span class="cpf">&quot;ml/logisticregress.h&quot;</span>
668+
<span class="cp">#include</span><span class="w"> </span><span class="cpf">&quot;model/metrics.h&quot;</span>
669+
<span class="cp">#include</span><span class="w"> </span><span class="cpf">&quot;model/train_test_split.h&quot;</span>
670+
<span class="cp">#include</span><span class="w"> </span><span class="cpf">&quot;parsers/csv.h&quot;</span>
<span class="cp">#include</span><span class="w"> </span><span class="cpf">&lt;stdio.h&gt;</span>
671+
672+
<span class="kt">int</span><span class="w"> </span><span class="nf">main</span><span class="p">(</span><span class="kt">void</span><span class="p">)</span><span class="w"> </span><span class="p">{</span>
673+
<span class="w"> </span><span class="n">CSV</span><span class="w"> </span><span class="o">*</span><span class="n">csv_reader</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">csv_init</span><span class="p">(</span><span class="mi">569</span><span class="p">,</span><span class="w"> </span><span class="mi">31</span><span class="p">,</span><span class="w"> </span><span class="sc">&#39;,&#39;</span><span class="p">);</span>
674+
<span class="w"> </span><span class="n">csv_parse</span><span class="p">(</span><span class="n">csv_reader</span><span class="p">,</span><span class="w"> </span><span class="s">&quot;data/bcancer.csv&quot;</span><span class="p">);</span>
675+
676+
<span class="w"> </span><span class="n">Mat</span><span class="w"> </span><span class="o">*</span><span class="n">X</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">csv_get_mat_slice</span><span class="p">(</span><span class="n">csv_reader</span><span class="p">,</span><span class="w"> </span><span class="p">(</span><span class="n">Slice</span><span class="p">){</span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="mi">31</span><span class="p">});</span>
677+
<span class="w"> </span><span class="n">Mat</span><span class="w"> </span><span class="o">*</span><span class="n">Y</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">csv_get_mat_slice</span><span class="p">(</span><span class="n">csv_reader</span><span class="p">,</span><span class="w"> </span><span class="p">(</span><span class="n">Slice</span><span class="p">){</span><span class="mi">0</span><span class="p">,</span><span class="w"> </span><span class="mi">1</span><span class="p">});</span>
678+
<span class="w"> </span><span class="n">Mat</span><span class="w"> </span><span class="o">*</span><span class="n">X_train</span><span class="p">,</span><span class="w"> </span><span class="o">*</span><span class="n">X_test</span><span class="p">,</span><span class="w"> </span><span class="o">*</span><span class="n">Y_train</span><span class="p">,</span><span class="w"> </span><span class="o">*</span><span class="n">Y_test</span><span class="p">;</span>
679+
680+
<span class="w"> </span><span class="n">train_test_split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span><span class="w"> </span><span class="n">Y</span><span class="p">,</span><span class="w"> </span><span class="o">&amp;</span><span class="n">X_train</span><span class="p">,</span><span class="w"> </span><span class="o">&amp;</span><span class="n">X_test</span><span class="p">,</span><span class="w"> </span><span class="o">&amp;</span><span class="n">Y_train</span><span class="p">,</span><span class="w"> </span><span class="o">&amp;</span><span class="n">Y_test</span><span class="p">,</span><span class="w"> </span><span class="mf">0.3</span><span class="p">,</span><span class="w"> </span><span class="mi">101</span><span class="p">);</span>
681+
682+
<span class="w"> </span><span class="n">logregress_set_max_iter</span><span class="p">(</span><span class="mi">2000</span><span class="p">);</span>
683+
<span class="w"> </span><span class="n">LogisticRegressionModel</span><span class="w"> </span><span class="o">*</span><span class="n">model</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">logregress_init</span><span class="p">();</span>
684+
<span class="w"> </span><span class="n">logregress_fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">X_train</span><span class="p">,</span><span class="w"> </span><span class="n">Y_train</span><span class="p">);</span>
685+
686+
<span class="w"> </span><span class="c1">// printf(&quot;prediction: %lf\n&quot;, logregress_predict(model, (double[]){15.22, 30.62, 103.4, 716.9, ... , 0}, 30));</span>
687+
<span class="w"> </span><span class="n">Array</span><span class="w"> </span><span class="o">*</span><span class="n">preds</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">logregress_predict_many</span><span class="p">(</span><span class="n">model</span><span class="p">,</span><span class="w"> </span><span class="n">X_test</span><span class="p">);</span>
688+
<span class="w"> </span><span class="n">Array</span><span class="w"> </span><span class="o">*</span><span class="n">y_true</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">mat_get_col_arr</span><span class="p">(</span><span class="n">Y_test</span><span class="p">,</span><span class="w"> </span><span class="mi">0</span><span class="p">);</span>
689+
690+
<span class="w"> </span><span class="n">logregress_print</span><span class="p">(</span><span class="n">model</span><span class="p">);</span>
691+
692+
<span class="w"> </span><span class="n">printf</span><span class="p">(</span><span class="s">&quot;confusion matrix: </span><span class="se">\n</span><span class="s">&quot;</span><span class="p">);</span>
693+
<span class="w"> </span><span class="n">Mat</span><span class="w"> </span><span class="o">*</span><span class="n">conf_mat</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">model_confusion_matrix</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span><span class="w"> </span><span class="n">preds</span><span class="p">);</span>
694+
<span class="w"> </span><span class="n">mat_print</span><span class="p">(</span><span class="n">conf_mat</span><span class="p">);</span>
695+
696+
<span class="w"> </span><span class="n">arr_free</span><span class="p">(</span><span class="n">y_true</span><span class="p">);</span>
697+
<span class="w"> </span><span class="n">arr_free</span><span class="p">(</span><span class="n">preds</span><span class="p">);</span>
698+
<span class="w"> </span><span class="n">logregress_free</span><span class="p">(</span><span class="n">model</span><span class="p">);</span>
699+
<span class="w"> </span><span class="n">mat_free_many</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span><span class="w"> </span><span class="n">X</span><span class="p">,</span><span class="w"> </span><span class="n">Y</span><span class="p">,</span><span class="w"> </span><span class="n">X_test</span><span class="p">,</span><span class="w"> </span><span class="n">X_train</span><span class="p">,</span><span class="w"> </span><span class="n">Y_test</span><span class="p">,</span><span class="w"> </span><span class="n">Y_train</span><span class="p">,</span><span class="w"> </span><span class="n">conf_mat</span><span class="p">);</span>
700+
<span class="w"> </span><span class="n">csv_free</span><span class="p">(</span><span class="n">csv_reader</span><span class="p">);</span>
701+
<span class="p">}</span>
702+
</code></pre></div>
703+
<div class="highlight"><pre><span></span><code><span class="go">LogisticRegressionModel(bias: 0.5159147, loss: -12.4263621, weights: 0x5556e8a732c0)</span>
704+
<span class="go">weights:</span>
705+
<span class="go">1546.6922009</span>
706+
<span class="go">1139.6829595</span>
707+
<span class="go">8552.1648900</span>
708+
<span class="go">2522.0044946</span>
709+
<span class="go">11.8724211</span>
710+
<span class="go">-19.3345598</span>
711+
<span class="go">-44.9646156</span>
712+
<span class="go">-18.4984994</span>
713+
<span class="go">23.8378678</span>
714+
<span class="go">10.1676564</span>
715+
<span class="go">0.2338315</span>
716+
<span class="go">103.3839701</span>
717+
<span class="go">-139.7864354</span>
718+
<span class="go">-4498.8563443</span>
719+
<span class="go">0.2662770</span>
720+
<span class="go">-6.5798244</span>
721+
<span class="go">-8.6158697</span>
722+
<span class="go">-1.6938180</span>
723+
<span class="go">1.6508702</span>
724+
<span class="go">-0.3857419</span>
725+
<span class="go">1650.7843571</span>
726+
<span class="go">1445.0283208</span>
727+
<span class="go">8312.7672485</span>
728+
<span class="go">-4024.9280673</span>
729+
<span class="go">13.2972726</span>
730+
<span class="go">-72.4527931</span>
731+
<span class="go">-111.8298475</span>
732+
<span class="go">-26.6204266</span>
733+
<span class="go">28.0612275</span>
734+
<span class="go">5.4099162</span>
735+
<span class="go">confusion matrix:</span>
736+
<span class="go"> 57.00 10.00</span>
737+
<span class="go"> 2.00 101.00</span>
738+
</code></pre></div>
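<p>For readers unfamiliar with how such a matrix is tallied, here is a minimal standalone sketch. The layout <code>cm[actual][predicted]</code>, the 0.5 threshold, and the name <code>confusion_matrix_2x2</code> are assumptions for this example; the library's <code>model_confusion_matrix</code> may use a different convention.</p>

```c
#include <stddef.h>

/* Tally a 2x2 confusion matrix from 0/1 labels and predicted scores.
 * Assumed layout: cm[actual][predicted].
 * Scores at or above 0.5 count as class 1.
 * Hypothetical helper for illustration, not part of the library. */
void confusion_matrix_2x2(const double *y_true, const double *y_prob,
                          size_t n, unsigned cm[2][2]) {
    cm[0][0] = cm[0][1] = cm[1][0] = cm[1][1] = 0;
    for (size_t i = 0; i < n; i++) {
        int actual = y_true[i] >= 0.5;    /* 0 or 1 */
        int predicted = y_prob[i] >= 0.5; /* 0 or 1 */
        cm[actual][predicted]++;
    }
}
```

<p>Under this layout, the diagonal entries count correct predictions and the off-diagonal entries count the two kinds of error.</p>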
739+
<p>For comparison, here is the confusion matrix generated by sklearn:</p>
740+
<div class="highlight"><pre><span></span><code><span class="n">array</span><span class="p">([[</span> <span class="mi">59</span><span class="p">,</span> <span class="mi">7</span><span class="p">],</span>
741+
<span class="p">[</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">102</span><span class="p">]])</span>
742+
</code></pre></div>
743+
<p>We are pretty close.
744+
Check out the Python implementation <a href="https://gitlab.com/adwaithrajesh/linear-ml-test/-/blob/main/notebooks/log.ipynb">here</a>.</p>
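<p>To put a number on "pretty close", one option is to compare accuracy, the share of samples on the diagonal of each confusion matrix. The helper <code>accuracy_2x2</code> below is a hypothetical name used only for this sketch. It does not depend on which axis holds the actual labels, since only the diagonal and the total are used.</p>

```c
/* Accuracy = correctly classified samples / all samples,
 * computed from a 2x2 confusion matrix.
 * Hypothetical helper for illustration, not part of the library. */
double accuracy_2x2(double cm[2][2]) {
    double correct = cm[0][0] + cm[1][1];
    double total = correct + cm[0][1] + cm[1][0];
    return total > 0.0 ? correct / total : 0.0;
}
```

<p>With the two matrices shown above, this gives 158/170 (about 0.929) for the C implementation and 161/171 (about 0.942) for sklearn. The totals differ by one sample, so the two test splits are not identical and the comparison is only approximate.</p>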