---
title: "BalgobinS.DA5030.Project.Rmd"
output:
  pdf_document: default
  html_document: default
date: "2023-08-08"
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
# Predicting HCV Infection and Progression Using Machine Learning
This data set contains lab values for blood donors and hepatitis C patients, along with demographic attributes such as age and sex. It is taken from the UC Irvine repository. The target variable is categorical: we aim to predict whether an individual has HCV or progressed HCV (fibrosis/cirrhosis). We use multi-class classification and group patients in progressed stages of hepatitis into one class, because the class imbalance is extreme and predicting fibrosis and cirrhosis separately would leave too few cases per class for adequate statistical power. Hepatitis C is a viral infection that causes liver inflammation. Fibrosis occurs when there is a limited accumulation of scar tissue, and cirrhosis occurs when fibrosis is extensive. Among those with a chronic HCV infection, 15-20% progress to end-stage liver disease. HCV remains a significant public health challenge. To reap the benefits of novel therapies, we need to shrink the undiagnosed population and diagnose early, so that patients can be treated before experiencing the long-term ramifications of HCV.
```{r loadlibraries, echo=FALSE, message=FALSE}
library(ggplot2)
library(ggpubr)
library(tidyverse)
library(dplyr)
library(class)
library(klaR)
library(caret)
library(rpart)
library(readxl)
library(psych)
library(rsample)
library(randomForest)
library(e1071)
library(pROC)
library(xgboost)
library(ROSE)
library(smotefamily)
library(DMwR)
```
## Load data
```{r loaddata}
# Load data from google drive URL
hcv_data <- read.csv("https://drive.google.com/uc?export=download&id=1IBnIVbW_uSiDxp_kGwkmhk2D3GVEXaQB")
```
## Explore data
```{r exploredata}
dim(hcv_data)
glimpse(hcv_data)
summary(hcv_data)
```
From the initial data exploration, we see that the target variable is "Category" and has 5 classes. There are 2 demographic variables, Age (numeric) and Sex (categorical), and 10 continuous variables that hold the lab test values. The first column, X, appears to be the patient ID. The mean and median are close for several variables, but the maximum values are far from both, indicating skewed distributions with outliers for some of the variables. ALB, CHOL, and PROT might be normally distributed. We will remove X because a patient ID carries no predictive information, and also remove the rows labeled "0s=suspect Blood Donor". These likely represent donors who are suspected to have an HCV infection and could be in any stage of HCV, so they only add noise to the data. Given the extreme class imbalance, we will group patients who have fibrosis and cirrhosis into one category called "Progressed HCV".
```{r removedata}
# Remove ID column
hcv_data <- subset(hcv_data, select = -X)
# Remove rows labeled "0s=suspect Blood Donor" in Category
hcv_data <- hcv_data %>% filter(Category != "0s=suspect Blood Donor")
# Combine fibrosis and cirrhosis into one "Progressed HCV" class
vals_to_replace <- c("2=Fibrosis", "3=Cirrhosis")
replacement_val <- "Progressed HCV"
hcv_data <- hcv_data %>%
  mutate(
    Category = ifelse(Category %in% vals_to_replace, replacement_val, Category)
  )
```
```{r distplots}
# Plot a histogram for each numeric variable
hist(hcv_data$Age)
hist(hcv_data$ALB)
hist(hcv_data$ALP)
hist(hcv_data$ALT)
hist(hcv_data$AST)
hist(hcv_data$BIL)
hist(hcv_data$CHE)
hist(hcv_data$CHOL)
hist(hcv_data$CREA)
hist(hcv_data$GGT)
hist(hcv_data$PROT)
```
Age is fairly normal with a slight right skew. ALB seems normally distributed. Most of ALP's values fall within a small range, but there is a small proportion of outliers. ALT, AST, and BIL have a significant right skew. CHE has a normal distribution. CHOL has a fairly normal distribution. CREA has a significant right skew. GGT is right skewed. PROT is slightly left skewed.
## Identifying Outliers
```{r outliers}
# Function to calculate z-scores, ignoring missing values
calc_z_score <- function(x) {
  mu <- mean(x, na.rm = TRUE)
  sigma <- sd(x, na.rm = TRUE)
  (x - mu) / sigma
}
# Calculate z_scores and find outliers for each column
# Age
z_scores <- calc_z_score(hcv_data$Age)
range(z_scores, na.rm=TRUE)
age_out <- which(abs(z_scores) > 2.5, arr.ind = TRUE)
length(age_out)
# ALB
z_scores <- calc_z_score(hcv_data$ALB)
range(z_scores, na.rm=TRUE)
median(z_scores, na.rm=TRUE)
alb_out <- which(abs(z_scores) > 4, arr.ind = TRUE)
length(alb_out)
# ALP
z_scores <- calc_z_score(hcv_data$ALP)
range(z_scores, na.rm=TRUE)
alp_out <- which(abs(z_scores) > 3, arr.ind = TRUE)
length(alp_out)
# ALT
z_scores <- calc_z_score(hcv_data$ALT)
range(z_scores, na.rm=TRUE)
alt_out <- which(abs(z_scores) > 4, arr.ind = TRUE)
length(alt_out)
# AST
z_scores <- calc_z_score(hcv_data$AST)
range(z_scores, na.rm=TRUE)
ast_out <- which(abs(z_scores) > 5, arr.ind = TRUE)
length(ast_out)
# BIL
z_scores <- calc_z_score(hcv_data$BIL)
range(z_scores, na.rm=TRUE)
bil_out <- which(abs(z_scores) > 2.5, arr.ind = TRUE)
length(bil_out)
# CHE
z_scores <- calc_z_score(hcv_data$CHE)
range(z_scores, na.rm=TRUE)
che_out <- which(abs(z_scores) > 3, arr.ind = TRUE)
length(che_out)
# CHOL
z_scores <- calc_z_score(hcv_data$CHOL)
range(z_scores, na.rm=TRUE)
chol_out <- which(abs(z_scores) > 3.5, arr.ind = TRUE)
length(chol_out)
hist(z_scores)
# CREA
z_scores <- calc_z_score(hcv_data$CREA)
range(z_scores, na.rm=TRUE)
crea_out <- which(abs(z_scores) > 2.5, arr.ind = TRUE)
length(crea_out)
hist(z_scores)
# GGT
z_scores <- calc_z_score(hcv_data$GGT)
range(z_scores, na.rm=TRUE)
ggt_out <- which(abs(z_scores) > 4, arr.ind = TRUE)
length(ggt_out)
hist(z_scores)
# PROT
z_scores <- calc_z_score(hcv_data$PROT)
range(z_scores, na.rm=TRUE)
prot_out <- which(abs(z_scores) > 3.5, arr.ind = TRUE)
length(prot_out)
hist(z_scores)
```
We looked at the z-scores, the range of z-scores, and the distributions of both the data and the z-scores to identify outliers. We also used domain knowledge to decide which z-score threshold to set and whether to remove particular outliers.

For age, there are only 7 outliers with a z-score threshold of 2.5. We chose not to remove these data points because they represent people a few years above 70, and this is important information: older patients are more likely to be in later stages of HCV infection. There are only 2 outliers for ALB using a z-score threshold of 4. They represent one high value and one low value, which are not necessarily related to HCV infection and could simply indicate dehydration or some other issue.

Studies have shown that AST, ALT, and ALP levels are significantly correlated with HCV viral load. There are 3 outliers using a threshold of 3 for ALP. We won't remove these because they belong to patients with HCV and carry important information. Using a threshold of 3 for ALT, we have 8 outliers, most of them in infected patients. To avoid losing valuable data, we won't remove these outliers. The same applies to AST. Using a threshold of 2.5 for BIL, we see 8 outliers. Since these represent high bilirubin levels in cirrhosis patients, we will not remove them. High bilirubin is usually linked to jaundice, which can indicate severe liver disease or cirrhosis, and there are few cirrhosis data points to begin with. Using a z-score threshold of 3 for CHE, there are 9 outliers: 4 are values in the normal range for CHE, and 5 are abnormal values found in patients with HCV infection. Decreased CHE can be due to liver damage. The mean z-score for CHE is 2.3. We will keep these outliers as they carry important information, aren't too far from the mean z-score, and come from a normal distribution.

For CHOL, it's unclear what the values represent, as there are different types of cholesterol. We won't remove outliers for CHOL given the lack of clarity around its relevance. The range of z-scores for CREA is very wide. The outliers seem to belong to patients who have cirrhosis, which makes sense: at later stages of liver disease, the kidneys may also fail to remove creatinine, causing increased levels. GGT also shows outliers among the hepatitis and cirrhosis patients. We won't remove these few outliers either, because the literature suggests that the variance in GGT values can help differentiate between fibrosis and more extensive scarring. There are only 4 PROT outliers using a z-score threshold of 3.5. These are not worth removing, especially since the data is normally distributed and we have a small dataset to begin with.

In summary, there aren't enough outliers in any variable to warrant removal, some of the outliers are relevant to particular classes, and our data has a class imbalance issue. Therefore we will keep the outliers in the data.
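The per-column checks above repeat the same three steps for every variable. As a minimal sketch, the counts could be collected in one pass. The helper name `count_outliers`, the toy data frame, and the thresholds below are illustrative assumptions, not the project's data or settings.

```r
# Hypothetical helper: count values with |z| above a per-column threshold.
count_outliers <- function(df, thresholds) {
  sapply(names(thresholds), function(col) {
    x <- df[[col]]
    z <- (x - mean(x, na.rm = TRUE)) / sd(x, na.rm = TRUE)
    sum(abs(z) > thresholds[[col]], na.rm = TRUE)
  })
}

# Toy data: column A has one planted extreme value, column B has a missing value.
set.seed(1)
toy <- data.frame(
  A = c(rnorm(20, mean = 50, sd = 5), 500),
  B = c(rnorm(20, mean = 10, sd = 1), NA)
)

outlier_counts <- count_outliers(toy, c(A = 3, B = 3))
outlier_counts
```

Keeping the thresholds in one named vector also makes it easy to see at a glance which cut-off was applied to which variable.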
## Correlation and Collinearity
```{r cor}
# Get numerical columns
numerical_cols <- hcv_data[, sapply(hcv_data, is.numeric)]
# Correlation matrix
cor(numerical_cols, use = "pairwise.complete.obs")
# Distribution and collinearity/correlation
pairs.panels(hcv_data)
```
We used a pairwise correlation test on the numerical columns to check for collinearity. There seems to be a weak to moderate correlation between GGT and AST and between ALB and PROT. PROT has a weak correlation with the target variable as well as the other variables aside from ALB. Based on domain knowledge, it seems to be far less important to liver function than the other lab values. We will remove PROT from the data. Age and sex also have weak correlation with the target variable, so we will remove those as well. We will not be using PCA as there is a non-linear relationship between features.
## Cleaning and Shaping Data
```{r missingdata}
# Remove PROT, Age, and Sex
hcv_data <- subset(hcv_data, select = -c(PROT, Age, Sex))
# Find the number of missing values in each column
missing_counts <- colSums(is.na(hcv_data))
missing_counts
```
We have some missing data for ALB, ALP, ALT, and CHOL. We don't want to lose information by dropping rows, and because the data is very skewed, we don't want to impute from the raw values. Instead, we first transform the features to make them more normal and then impute.
```{r featuretransformation}
# Let's see if we can transform features to look more normal before imputation
# ALB
hist(log(hcv_data$ALB))
hcv_data$ALB <- log(hcv_data$ALB)
# ALP
hist(log(hcv_data$ALP))
hcv_data$ALP <- log(hcv_data$ALP)
# ALT
hist(log(hcv_data$ALT))
hcv_data$ALT <- log(hcv_data$ALT)
# AST
hist(log(hcv_data$AST))
hist(1/hcv_data$AST)
hcv_data$AST <- 1 / hcv_data$AST
# BIL
hist(log(hcv_data$BIL))
hcv_data$BIL <- log(hcv_data$BIL)
# CREA
hist(log(hcv_data$CREA))
hcv_data$CREA <- log(hcv_data$CREA)
# GGT
hist(log(hcv_data$GGT))
hcv_data$GGT <- log(hcv_data$GGT)
```
```{r aftertransformation}
pairs.panels(hcv_data)
```
We transformed the features that were not normally distributed using log transformation or inverse transformation, depending on which transformation produced a more normal distribution.
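One way to make "more normal" concrete is to compare sample skewness before and after each candidate transformation. The sketch below uses simulated right-skewed data and a hand-written moment-based skewness function. Both `skewness_moment` and the simulated vector are assumptions for illustration, not part of the analysis above.

```r
# Moment-based sample skewness: third central moment divided by
# the second central moment raised to the power 1.5.
skewness_moment <- function(x) {
  x <- x[!is.na(x)]
  m <- mean(x)
  mean((x - m)^3) / (mean((x - m)^2))^1.5
}

set.seed(42)
raw <- rlnorm(500, meanlog = 3, sdlog = 0.8)  # simulated right-skewed "lab value"

skew_raw <- skewness_moment(raw)
skew_log <- skewness_moment(log(raw))
skew_inv <- skewness_moment(1 / raw)

c(raw = skew_raw, log = skew_log, inverse = skew_inv)
```

A skewness close to zero suggests a roughly symmetric distribution, so the transformation with the smallest absolute skewness is a reasonable candidate to check against the histogram.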
```{r imputation}
# Flag numeric columns
numerical_cols <- sapply(hcv_data, is.numeric)
# Impute NAs in each numeric column with that column's median
hcv_data[numerical_cols] <- lapply(hcv_data[numerical_cols], function(x) {
  x[is.na(x)] <- median(x, na.rm = TRUE)
  x
})
```
Given that the feature distributions were originally skewed and we did not want to lose data given the small sample size and class imbalance, we imputed missing values with the median.
```{r dummycodes}
# Factoring
hcv_data$Category <- as.factor(hcv_data$Category)
```
## Model Construction and Evaluation
### Downsampling and oversampling
```{r}
# Split the data into subsets
majority_data <- hcv_data[hcv_data$Category == '0=Blood Donor', ]
hepatitis_data <- hcv_data[hcv_data$Category == '1=Hepatitis', ]
progressed_hcv_data <- hcv_data[hcv_data$Category == 'Progressed HCV', ]
# Convert 'Category' to factor in minority datasets
hepatitis_data$Category <- as.factor(hepatitis_data$Category)
progressed_hcv_data$Category <- as.factor(progressed_hcv_data$Category)
# Downsample the majority class
set.seed(123)
num_to_sample <- 5 * (nrow(hepatitis_data) + nrow(progressed_hcv_data))
num_to_sample <- min(num_to_sample, nrow(majority_data))
downsampled_majority_data <- majority_data[sample(nrow(majority_data), num_to_sample), ]
set.seed(123)
# Manually oversample '1=Hepatitis'
oversample_size_hepatitis <- 5 * nrow(hepatitis_data) # Adjust the multiplier as needed
oversampled_hepatitis <- hepatitis_data[sample(1:nrow(hepatitis_data), size = oversample_size_hepatitis, replace = TRUE), ]
# Manually oversample 'Progressed HCV'
#oversample_size_progressed_hcv <- 5 * nrow(progressed_hcv_data) # Adjust the multiplier as needed
#oversampled_progressed_hcv <- progressed_hcv_data[sample(1:nrow(progressed_hcv_data), size = oversample_size_progressed_hcv, replace = TRUE), ]
# Combine the downsampled majority class with the oversampled minority class
balanced_data <- rbind(downsampled_majority_data, oversampled_hepatitis, progressed_hcv_data)
```
### Create training and validation sets
```{r trainvalsets}
# Fixed seed for reproducibility
set.seed(123)
# Randomize data
balanced_data <- balanced_data[sample(nrow(balanced_data)), ]
# Create a stratified random train-validation split
train_index <- createDataPartition(y = balanced_data$Category,
                                   p = 0.8,
                                   list = FALSE,
                                   times = 1)
# Training data
train_data <- balanced_data[train_index, ]
# Validation data
validation_data <- balanced_data[-train_index, ]
```
The data has a significant class imbalance. About 88% of the data falls under "Blood Donor", with only ~4% under "Hepatitis" and ~8% under "Progressed HCV". We downsample the majority class to a 5:1 ratio relative to the combined minority classes, to reduce the model's bias toward the majority class, and randomly oversample the hepatitis class because of its very small sample size. Because the progressed class has more variation, we chose not to oversample it. We then randomize the dataset and use stratified sampling to split the data 80/20 while preserving the class distribution in both the training and validation sets. SMOTE would be a preferable alternative to manual oversampling given its more advanced statistical properties, but the hepatitis sample size was too small to use the popular SMOTE package.
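One caveat with resampling before the split: oversampling with replacement creates exact copies of rows, and a random split can then place copies of the same patient in both the training and validation sets, which tends to inflate validation scores. A common alternative is to split first and resample only the training portion. The sketch below shows that order on a toy data frame using base R. The data, the per-class 80/20 split, and the oversampling multiplier are illustrative assumptions, not the pipeline used here.

```r
set.seed(123)
# Toy data: 100 majority rows and 10 minority rows, each with a unique id.
toy <- data.frame(
  id = 1:110,
  x = rnorm(110),
  Category = factor(rep(c("majority", "minority"), times = c(100, 10)))
)

# 1. Stratified 80/20 split: sample 80% of the row indices within each class.
train_idx <- unlist(lapply(split(seq_len(nrow(toy)), toy$Category), function(idx) {
  sample(idx, size = floor(0.8 * length(idx)))
}))
train <- toy[train_idx, ]
validation <- toy[-train_idx, ]

# 2. Oversample the minority class inside the training set only
#    (4 extra copies per row on average, for roughly 5x the original count).
minority_train <- train[train$Category == "minority", ]
extra <- minority_train[sample(nrow(minority_train),
                               size = 4 * nrow(minority_train),
                               replace = TRUE), ]
train_balanced <- rbind(train, extra)

# No validation row, or copy of one, appears in the training data.
shared_ids <- intersect(train_balanced$id, validation$id)
length(shared_ids)
table(train_balanced$Category)
```

With this order, the validation set keeps the original class proportions and contains only rows the models have never seen.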
### XGBoost
```{r xgboost}
# Train XGBoost model
# Set target variable
target_column <- "Category"
# Encoding for xgboost - convert to factor
train_data_xg <- train_data
validation_data_xg <- validation_data
#train_data_xg$Category <- as.numeric(as.factor(train_data_xg$Category)) - 1
#validation_data_xg$Category <- as.numeric(as.factor(validation_data_xg$Category)) -1
train_labels_numeric <- as.numeric(as.factor(train_data_xg$Category)) - 1
validation_labels_factor <- factor(validation_data_xg$Category, levels = c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"))
# Remove target variable from data to get features
features <- names(train_data_xg)[names(train_data_xg) != target_column]
# Convert data to matrix format
train_matrix <- as.matrix(train_data_xg[, features])
validation_matrix <- as.matrix(validation_data_xg[, features])
# xgboost parameters for multi-class classification
params <- list(
  objective = "multi:softprob",
  num_class = length(unique(train_data_xg$Category)),
  max_depth = 3
)
# Train the model
xgb_model <- xgboost(data = train_matrix,
                     label = train_labels_numeric,
                     params = params,
                     nrounds = 50)
# Make predictions on the validation data
xgb_predictions <- predict(xgb_model, newdata = validation_matrix)
# Convert predictions to class labels
max_prob_index <- matrix(xgb_predictions, ncol = length(unique(train_data_xg$Category)), byrow = TRUE)
xgb_class_labels <- apply(max_prob_index, 1, which.max) - 1
xgb_class_labels <- factor(xgb_class_labels, levels = 0:2, labels = c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"))
# Confusion matrix
confusion_matrix <- confusionMatrix(xgb_class_labels, validation_labels_factor)
confusion_matrix
# Save metrics
# Extract the table from the confusion matrix
# (caret puts predictions in rows and reference classes in columns)
cm_table <- confusion_matrix$table
# Initialize vectors to store the metrics for each class
precision <- numeric(length = ncol(cm_table))
recall <- numeric(length = ncol(cm_table))
f1_score <- numeric(length = ncol(cm_table))
# Calculate metrics for each class
for (i in 1:ncol(cm_table)) {
  tp <- cm_table[i, i]
  fp <- sum(cm_table[i, ]) - tp  # predicted as class i, actually another class
  fn <- sum(cm_table[, i]) - tp  # actually class i, predicted as another class
  precision[i] <- tp / (tp + fp)
  recall[i] <- tp / (tp + fn)
  f1_score[i] <- 2 * (precision[i] * recall[i]) / (precision[i] + recall[i])
}
# Output the results
xgb_metrics <- data.frame(
  Class = colnames(cm_table),
  Precision = precision,
  Recall = recall,
  F1_Score = f1_score
)
xgb_metrics
```
XGBoost is a gradient boosting method. It can handle both regression and classification problems and is known for high predictive accuracy and for handling complex relationships in data. Gradient boosting is an ensemble learning technique that combines multiple weak learners, primarily decision trees, into a strong predictive model. Trees are added sequentially, and each new tree focuses on correcting the errors made by the model so far. We used a small number of rounds and a shallow depth due to the small dataset and class imbalance. We used the holdout method to test this model. We use kappa, precision, recall, and F1 score to evaluate the model, as these metrics are useful for imbalanced datasets.
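The chunk above treats the `multi:softprob` prediction as a flat vector of length n × k (n rows, k classes), grouped row by row, which is why it reshapes with `byrow = TRUE` before taking the arg-max. A toy illustration of that reshaping with made-up probabilities (no model is trained here, and the numbers are assumptions for illustration):

```r
class_levels <- c("0=Blood Donor", "1=Hepatitis", "Progressed HCV")

# Made-up flat probability vector for 2 rows x 3 classes:
# row 1 = (0.7, 0.2, 0.1), row 2 = (0.1, 0.3, 0.6)
flat_probs <- c(0.7, 0.2, 0.1, 0.1, 0.3, 0.6)

# byrow = TRUE puts each consecutive group of 3 values on its own row.
prob_matrix <- matrix(flat_probs, ncol = length(class_levels), byrow = TRUE)

# which.max gives a 1-based column index, which maps directly onto the labels.
pred_index <- apply(prob_matrix, 1, which.max)
pred_labels <- factor(class_levels[pred_index], levels = class_levels)
pred_labels
```

Filling the matrix column-wise instead (the default) would mix probabilities from different rows, so the `byrow = TRUE` argument is essential when the input really is grouped by row.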
A kappa value of 0.94 indicates very good agreement between the model's predictions and the actual classes. The positive predictive value (PPV) and negative predictive value (NPV) are very high for all classes, indicating the model's effectiveness in class prediction. Per-class precision and recall are read from caret's confusion matrix, where rows are predictions and columns are reference classes. There is excellent performance in classifying Blood Donor and Hepatitis cases, with perfect recall (100%) and very high precision. For the Progressed HCV class the pattern reverses: precision is perfect, but recall is lower, at 70%. This suggests the model rarely raises a false alarm for Progressed HCV but misses some true cases. Because missing a diagnosis is more costly than falsely classifying a patient as Progressed HCV and running additional labs, recall on this class is the main weakness. The F1 score, the harmonic mean of precision and recall, is high for Blood Donor and Hepatitis and moderate for Progressed HCV due to the lower recall. Overall, this is a strong model whose performance could potentially be improved by focusing on recall for the Progressed HCV class, possibly through resampling or collecting more data.
### RandomForest
```{r randomforest}
# Train a bagged model using randomForest
set.seed(123)
bagged_model <- randomForest(Category ~ ., data = train_data)
predictions <- predict(bagged_model, newdata = validation_data)
# Confusion matrix
confusion_matrix <- confusionMatrix(predictions, validation_data$Category)
confusion_matrix
# Save metrics
# Extract the table from the confusion matrix
# (caret puts predictions in rows and reference classes in columns)
cm_table <- confusion_matrix$table
# Initialize vectors to store the metrics for each class
precision <- numeric(length = ncol(cm_table))
recall <- numeric(length = ncol(cm_table))
f1_score <- numeric(length = ncol(cm_table))
# Calculate metrics for each class
for (i in 1:ncol(cm_table)) {
  tp <- cm_table[i, i]
  fp <- sum(cm_table[i, ]) - tp  # predicted as class i, actually another class
  fn <- sum(cm_table[, i]) - tp  # actually class i, predicted as another class
  precision[i] <- tp / (tp + fp)
  recall[i] <- tp / (tp + fn)
  f1_score[i] <- 2 * (precision[i] * recall[i]) / (precision[i] + recall[i])
}
# Output the results
rf_metrics <- data.frame(
  Class = colnames(cm_table),
  Precision = precision,
  Recall = recall,
  F1_Score = f1_score
)
rf_metrics
```
The RandomForest model is versatile and robust. It can handle class imbalance and non-linearity effectively. Its ensemble nature mitigates overfitting, and it works fairly well without extensive parameter tuning. We use the holdout method for testing, and we use kappa, precision, recall, and F1 score to evaluate the model, as these metrics are useful for imbalanced datasets.
The RandomForest model demonstrates strong performance in identifying the Blood Donor and Hepatitis classes, as evidenced by the F1 score and kappa (0.93). Reading caret's confusion matrix with predictions in rows and reference classes in columns, the Hepatitis class has very strong recall coupled with high precision, suggesting every true Hepatitis case is detected. The F1 score of 96% indicates efficient identification of Hepatitis cases with minimal false negatives. The precision, recall, and F1 score are moderate for the Progressed HCV class, which has a few more false positives and false negatives: 2 of the 10 Progressed HCV cases are misclassified as Hepatitis and 1 is misclassified as Blood Donor. We may be able to improve this model by using class weights that assign a larger weight to the minority class.
### SVM
```{r svm}
# Train SVM model
svm_model <- svm(Category ~ ., data = train_data, kernel = "radial")
# Make predictions on validation data
svm_predictions <- predict(svm_model, newdata = validation_data)
# Confusion matrix
confusion_matrix <- confusionMatrix(svm_predictions, validation_data$Category)
confusion_matrix
# Save metrics
# Extract the table from the confusion matrix
# (caret puts predictions in rows and reference classes in columns)
cm_table <- confusion_matrix$table
# Initialize vectors to store the metrics for each class
precision <- numeric(length = ncol(cm_table))
recall <- numeric(length = ncol(cm_table))
f1_score <- numeric(length = ncol(cm_table))
# Calculate metrics for each class
for (i in 1:ncol(cm_table)) {
  tp <- cm_table[i, i]
  fp <- sum(cm_table[i, ]) - tp  # predicted as class i, actually another class
  fn <- sum(cm_table[, i]) - tp  # actually class i, predicted as another class
  precision[i] <- tp / (tp + fp)
  recall[i] <- tp / (tp + fn)
  f1_score[i] <- 2 * (precision[i] * recall[i]) / (precision[i] + recall[i])
}
# Output the results
svm_metrics <- data.frame(
  Class = colnames(cm_table),
  Precision = precision,
  Recall = recall,
  F1_Score = f1_score
)
svm_metrics
```
SVMs with appropriate kernel functions can be effective in capturing non-linearity. Class imbalance can be addressed by tuning the cost and kernel parameters or by weighting classes. Although the math is complicated, SVMs have a reliable theoretical foundation and perform well in complex scenarios. Kernel functions implicitly map the input data into a higher-dimensional space, which can help capture more complex relationships. Here we use an RBF kernel, which is highly versatile and effective for non-linear data and can handle cases where the relationship between class labels and features is more complex. We use the confusion matrix, kappa, precision, recall, and F1 score to evaluate the model, as these metrics are useful for imbalanced datasets.
The SVM model has high precision, recall, and F1 scores for the Blood Donor class, reflecting its effectiveness in classifying Blood Donor cases. Performance is good for the Hepatitis class: predictions are fairly accurate, but with a slightly higher rate of false positives and false negatives. The model misclassifies 2 out of 24 Hepatitis cases as Blood Donor and 1 as Progressed HCV. It also misclassifies 2 out of 10 Progressed HCV cases as Hepatitis and 1 as Blood Donor. These are false negatives for their true classes, which matter more clinically than false positives, so this is the area most in need of improvement, for example by adding more data. The kappa of 0.86 indicates strong agreement between the model's predictions and the actual class labels.
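Kappa is reported for every model, so it may help to see how it is derived from a confusion matrix: it compares the observed accuracy with the accuracy expected by chance from the row and column totals. The counts below are invented for illustration and are not the results of any model in this report, and `cohen_kappa` is a hypothetical helper name.

```r
# Invented 3x3 confusion matrix: rows = predictions, columns = reference.
cm <- matrix(c(50,  2, 1,
                3, 20, 2,
                0,  1, 9),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("A", "B", "C"),
                             Reference  = c("A", "B", "C")))

cohen_kappa <- function(cm) {
  n <- sum(cm)
  p_observed <- sum(diag(cm)) / n                     # overall accuracy
  p_expected <- sum(rowSums(cm) * colSums(cm)) / n^2  # agreement expected by chance
  (p_observed - p_expected) / (1 - p_expected)
}

kappa_value <- cohen_kappa(cm)
kappa_value
```

Because the chance term grows when one class dominates, kappa is lower than raw accuracy on imbalanced data, which is the reason it is a useful companion metric here.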
## Model Comparison
```{r comparemodels}
xgb_df <- data.frame(Value = xgb_metrics)
rf_df <- data.frame(Value = rf_metrics)
svm_df <- data.frame(Value = svm_metrics)
# Add a new column 'Model' to each dataframe
xgb_df$Model <- "XGBoost"
rf_df$Model <- "randomForest"
svm_df$Model <- "SVM"
# Combine all three dataframes into one
combined_df <- rbind(xgb_df, rf_df, svm_df)
combined_df <- combined_df[, c("Model", "Value.Class", "Value.Precision", "Value.Recall", "Value.F1_Score")]
combined_df
```
In comparing the performance of the XGBoost, RandomForest, and SVM models across the three classes, Blood Donor, Hepatitis, and Progressed HCV, some distinct patterns emerge. Precision and recall are read from caret's confusion matrix, with predictions in rows and reference classes in columns. For the Blood Donor class, XGBoost shows exceptional performance with perfect recall and the highest F1 score (0.993), closely followed by RandomForest and SVM, both with high F1 scores, though slightly lower. In the Hepatitis class, XGBoost and RandomForest outperform SVM, both achieving an F1 score of 0.960 compared to SVM's 0.894. The most notable differences are observed in the Progressed HCV class, where XGBoost has perfect precision and the highest F1 score (0.824), while RandomForest and SVM show identical precision, recall, and F1 scores (0.778, 0.700, and 0.737, respectively). All three models reach the same recall of 0.700 on this class, so XGBoost's advantage comes from avoiding false positives. Recall on Progressed HCV, the metric that matters most clinically, is the weakest result in the comparison.
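Because the comparison rests on per-class precision and recall, the orientation of the confusion matrix matters. caret's `confusionMatrix` puts predictions in rows and reference classes in columns, so precision divides the diagonal by the row sums and recall divides it by the column sums. A vectorized sketch on an invented table follows. The helper name `per_class_metrics` and the counts are assumptions for illustration.

```r
per_class_metrics <- function(cm) {
  tp <- diag(cm)
  precision <- tp / rowSums(cm)  # share of predictions for the class that are right
  recall <- tp / colSums(cm)     # share of true class members that are found
  data.frame(
    Class = colnames(cm),
    Precision = precision,
    Recall = recall,
    F1_Score = 2 * precision * recall / (precision + recall),
    row.names = NULL
  )
}

# Invented counts: rows = predictions, columns = reference.
cm <- matrix(c(40,  0, 2,
                0, 10, 1,
                0,  0, 7),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("A", "B", "C"),
                             Reference  = c("A", "B", "C")))

metrics <- per_class_metrics(cm)
metrics
```

In this toy table, class C has perfect precision but a recall of 0.7: the model never predicts C wrongly, yet it misses three of the ten true C cases.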
## Ensemble
```{r}
predictCategory <- function(new_data) {
  # XGBoost Model (Multi-class)
  # Train the model
  xgb_model <- xgboost(data = train_matrix,
                       label = train_labels_numeric,
                       params = params,
                       nrounds = 50)
  # Format new data as a feature matrix
  validation_matrix <- as.matrix(new_data[, features])
  # Make predictions on new data
  xgb_predictions <- predict(xgb_model, newdata = validation_matrix)
  # Convert predicted probabilities to class labels
  max_prob_index <- matrix(xgb_predictions, ncol = length(unique(train_data_xg$Category)), byrow = TRUE)
  xgb_class_labels <- apply(max_prob_index, 1, which.max) - 1
  xgb_class_labels <- factor(xgb_class_labels, levels = 0:2, labels = c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"))
  # Random Forest Model (Multi-class)
  set.seed(123)
  bagged_model <- randomForest(Category ~ ., data = train_data)
  rf_predictions <- predict(bagged_model, newdata = new_data)
  # SVM model (Multi-class)
  svm_model <- svm(Category ~ ., data = train_data, kernel = "radial")
  svm_predictions <- predict(svm_model, newdata = new_data)
  # cbind() on factors yields their integer codes (1, 2, 3), one column per model
  combined_predictions <- cbind(xgb_class_labels, rf_predictions, svm_predictions)
  # Create a function to calculate majority vote
  majority_vote <- function(row) {
    # Count the occurrences of each class code in the row
    counts <- table(row)
    # Get the class code with the maximum count (ties go to the lowest code)
    names(counts)[which.max(counts)]
  }
  # Apply the majority vote function to each row to get the final predictions
  ensemble_predictions <- apply(combined_predictions, 1, majority_vote)
  # Return the final predictions (class codes as strings)
  return(ensemble_predictions)
}
```
Here we create an ensemble that aggregates the 3 models and uses a majority vote to make predictions.
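A minimal sketch of the voting step on its own, working with label strings rather than factor codes. The tie-breaking rule (fall back to the first model's vote when all three disagree) is an assumption added for illustration. The function above instead resolves ties by taking the first class in table order. The model names and predictions are hypothetical.

```r
majority_vote_labels <- function(votes) {
  counts <- table(votes)
  winners <- names(counts)[counts == max(counts)]
  # A unique winner takes it; otherwise fall back to the first voter.
  if (length(winners) == 1) winners else votes[[1]]
}

# Toy predictions from three hypothetical models for four samples.
model_a <- c("Donor", "Hepatitis", "Progressed", "Progressed")
model_b <- c("Donor", "Hepatitis", "Hepatitis",  "Donor")
model_c <- c("Donor", "Donor",     "Hepatitis",  "Hepatitis")

votes <- cbind(model_a, model_b, model_c)
ensemble <- apply(votes, 1, majority_vote_labels)
ensemble
```

Putting the strongest individual model in the first column makes the fallback favor its prediction whenever the three models split three ways.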
## Comparison of Ensemble to Individual Models
```{r compareensembles1, echo=TRUE, results='hide'}
predictions <- predictCategory(validation_data)
mapping <- c("0=Blood Donor", "1=Hepatitis", "Progressed HCV")
# Map integer class codes back to labels, keeping all three levels
predictions <- factor(mapping[as.integer(predictions)], levels = mapping)
```
```{r compareensembles2}
confusion_matrix <- confusionMatrix(predictions, validation_data$Category)
confusion_matrix$table
confusion_matrix$byClass
confusion_matrix$overall
```
The ensemble model demonstrates robust performance across the three classes, Blood Donor, Hepatitis, and Progressed HCV, as indicated by metrics such as sensitivity (recall), specificity, precision, NPV, and F1 score. The model shows high precision and perfect recall for the Hepatitis class: every true Hepatitis case is detected, with very few false positives. The F1 score is also high (97.96%), reflecting a good balance between precision and recall. The ensemble also performs better than the individual models in predicting Progressed HCV, with higher sensitivity and F1 score. Most Progressed HCV cases are handled correctly, with 8/10 classified as Progressed HCV, 1/10 misclassified as Blood Donor, and 1/10 misclassified as Hepatitis. The kappa of the overall model is high (0.9413), indicating strong agreement between the model's predictions and the actual classes. We could boost sensitivity in the Progressed HCV class by obtaining more data for Fibrosis and Cirrhosis patients, using class weights, and oversampling the Progressed HCV class. Overall, the ensemble model is highly effective in identifying the Blood Donor and Hepatitis classes, with slightly weaker but still strong performance in the Progressed HCV class. The model's ability to maintain high precision and recall across classes is indicative of its robustness in this multi-class classification task.
### K-Fold Cross-Validation
```{r kcross}
balanced_data$Category <- factor(balanced_data$Category, levels = c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"))
features <- setdiff(names(balanced_data), "Category")
# Number of folds
k <- 3
set.seed(123)
# createFolds() returns the held-out (validation) row indices for each fold
folds <- createFolds(balanced_data$Category, k = k, list = TRUE)
# List and matrices for metrics (rows = classes, columns = folds)
results <- list()
precision_sum <- matrix(0, nrow = length(levels(balanced_data$Category)), ncol = k)
recall_sum <- matrix(0, nrow = length(levels(balanced_data$Category)), ncol = k)
f1_sum <- matrix(0, nrow = length(levels(balanced_data$Category)), ncol = k)
# Majority vote over one row of integer class codes (ties go to the lowest code)
majority_vote <- function(row) {
  # Count the occurrences of each class label in the row
  counts <- table(row)
  # Get the class label with the maximum count
  majority_label <- as.character(as.numeric(names(counts)[which.max(counts)]))
  return(majority_label)
}
for (i in seq_along(folds)) {
  # Split the data using stratified folds: hold out fold i, train on the other k - 1 folds
  trainingSet <- balanced_data[-folds[[i]], ]
  validationSet <- balanced_data[folds[[i]], ]
  # Train the ensemble model and make predictions
  train_labels_numeric <- as.numeric(as.factor(trainingSet$Category)) - 1
  train_matrix <- as.matrix(trainingSet[, features])
  validation_matrix <- as.matrix(validationSet[, features])
  # Train XGBoost Model
  xgb_model <- xgboost(data = train_matrix,
                       label = train_labels_numeric,
                       params = params,
                       nrounds = 50)
  # Make predictions on new data
  xgb_predictions <- predict(xgb_model, newdata = validation_matrix)
  # Convert predictions to class labels (one column of probabilities per class)
  max_prob_index <- matrix(xgb_predictions, ncol = length(levels(balanced_data$Category)), byrow = TRUE)
  xgb_class_labels <- apply(max_prob_index, 1, which.max) - 1
  xgb_class_labels <- factor(xgb_class_labels, levels = 0:2, labels = c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"))
  # Random Forest Model (Multi-class)
  set.seed(123)
  bagged_model <- randomForest(Category ~ ., data = trainingSet)
  rf_predictions <- predict(bagged_model, newdata = validationSet)
  # SVM model (Multi-class)
  svm_model <- svm(Category ~ ., data = trainingSet, kernel = "radial")
  svm_predictions <- predict(svm_model, newdata = validationSet)
  # Combined predictions (cbind() stores the factors as integer codes 1 to 3)
  combined_predictions <- cbind(xgb_class_labels, rf_predictions, svm_predictions)
  # Majority Vote
  ensemble_predictions <- apply(combined_predictions, 1, majority_vote)
  ensemble_predictions <- factor(ensemble_predictions, levels = c("1", "2", "3"), labels = c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"))
  # Confusion Matrix
  cm <- confusionMatrix(ensemble_predictions, validationSet$Category)
  results[[i]] <- cm
  # Extracting per-class metrics
  fold_metrics <- cm$byClass
  class_levels <- levels(validationSet$Category)
  for (j in seq_along(class_levels)) {
    class_name <- paste("Class:", class_levels[j])
    precision_sum[j, i] <- fold_metrics[class_name, "Precision"]
    recall_sum[j, i] <- fold_metrics[class_name, "Recall"]
    f1_sum[j, i] <- fold_metrics[class_name, "F1"]
  }
}
# Calculate average performance across all folds
avg_performance <- lapply(results, function(x) x$overall)
avg_performance <- do.call("rbind", avg_performance)
# Calculate average metrics for each class
average_precision <- rowMeans(precision_sum, na.rm = TRUE)
average_recall <- rowMeans(recall_sum, na.rm = TRUE)
average_f1 <- rowMeans(f1_sum, na.rm = TRUE)
# Combine the averages into a data frame
avg_metrics <- data.frame(
  Class = levels(balanced_data$Category),
  Precision = average_precision,
  Recall = average_recall,
  F1_Score = average_f1
)
avg_metrics
colMeans(avg_performance)
```
Here we use stratified k-fold cross-validation to further test our ensemble. K-fold cross-validation partitions the dataset into k roughly equal folds; in each iteration, the model is trained on k-1 folds and the remaining fold is used as a test set to evaluate it. Every data point is therefore in the test set exactly once and in the training set k-1 times. This makes the performance estimate less dependent on the specific way the data is split. Stratifying the folds keeps the class proportions similar in every fold, which is useful for imbalanced classes. The results lead to a similar conclusion to the evaluation above; however, the recall and F1 score for Progressed HCV are lower, which we hypothesize is due to the class imbalance. More data for Fibrosis and Cirrhosis patients could improve this model.
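The fold mechanics described above can be sketched in base R without any modeling code. The labels below are hypothetical (the class counts are invented, not taken from the HCV data), and the fold assignment is a simple per-class shuffle written for this illustration; it is not how `createFolds()` is implemented internally.

```r
# Hypothetical imbalanced labels (illustrative counts, not the HCV dataset).
set.seed(42)
labels <- factor(rep(c("0=Blood Donor", "1=Hepatitis", "Progressed HCV"),
                     times = c(30, 12, 6)))
n_folds <- 3

# Assign fold ids within each class so every fold keeps similar class proportions.
fold_id <- integer(length(labels))
for (cls in levels(labels)) {
  idx <- which(labels == cls)
  fold_id[idx] <- sample(rep(seq_len(n_folds), length.out = length(idx)))
}

for (fold in seq_len(n_folds)) {
  test_idx  <- which(fold_id == fold)  # held-out fold
  train_idx <- which(fold_id != fold)  # remaining k - 1 folds
  # A model would be fit on train_idx and evaluated on test_idx here.
  cat("Fold", fold, ": train =", length(train_idx), ", test =", length(test_idx), "\n")
}

# Class counts per fold
table(fold_id, labels)
```

Each row receives exactly one fold id, so it is held out once and used for training in the other folds, and the per-class shuffle ensures that even the smallest class appears in every test fold.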
## Using Ensemble to Predict New Data
```{r predictnewdata, echo=TRUE, results='hide'}
# Assuming new data has been transformed and cleaned
new_data <- data.frame(ALB = 3.8,
                       ALP = 3.5,
                       ALT = 3.2,
                       AST = 0.0007,
                       BIL = 1.8,
                       CHE = 10.12,
                       CHOL = 5.23,
                       CREA = 4.33,
                       GGT = 4.33,
                       Category = "")
# Prediction
prediction <- predictCategory(new_data)
prediction <- mapping[as.integer(prediction)]
```
```{r ensembleprediction}
prediction
```
Assuming that the new input data has been cleaned and transformed, the ensemble model predicts the new data to be HCV. We used lab values similar to those of a patient with HCV, and the model correctly predicted the patient to have HCV.