-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknn.ml
55 lines (48 loc) · 1.53 KB
/
knn.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
(* Number of neighbours consulted by [predict]. This is mutable module-level
   state; [fit_and_predict] overwrites it before predicting. *)
let k = ref 5

(** [set_k new_k] stores [new_k] as the neighbour count used by [predict]
    and returns it. No validation is performed on [new_k]. *)
let set_k new_k =
  let () = k := new_k in
  new_k
(** [vote labels] returns the most frequent element of [labels], i.e. the
    majority class among the neighbour samples. Labels are deduplicated with
    polymorphic [compare] and counted with [=]. When several labels share the
    highest count, the greatest of them under [compare] wins.
    @raise Failure ["nth"] if [labels] is empty. *)
let vote labels =
  let tally label =
    List.fold_left (fun n a -> if a = label then n + 1 else n) 0 labels
  in
  match List.sort_uniq compare labels with
  | [] -> failwith "nth"
  | first :: rest ->
    let pick (best, best_n) label =
      let n = tally label in
      (* [rest] is in ascending order, so [>=] hands a tie to the later,
         greater label. *)
      if n >= best_n then (label, n) else (best, best_n)
    in
    fst (List.fold_left pick (first, tally first) rest)
(** [euclidean_distance x y] is the Euclidean (L2) distance between the
    coordinate lists [x] and [y]: the square root of the sum of squared
    component-wise differences.
    @raise Invalid_argument if [x] and [y] have different lengths. *)
let euclidean_distance x y =
  let sum_sq =
    List.fold_left2
      (fun acc xi yi ->
        let d = xi -. yi in
        acc +. (d *. d))
      0. x y
  in
  Float.sqrt sum_sq
(** [arg_sort lst num] returns the indices of the [num] smallest elements of
    [lst], in ascending order of value (an argsort truncated to its first
    [num] entries). Elements are ordered with polymorphic [compare]. Equal
    elements keep ascending index order because the sort is stable.
    Returns every index when [num >= List.length lst] and [[]] when
    [num <= 0]. *)
let arg_sort lst num =
  lst
  |> List.mapi (fun i elt -> (i, elt))
  (* [stable_sort] guarantees the tie order: equidistant neighbours are
     taken in their original index order. *)
  |> List.stable_sort (fun (_, a) (_, b) -> compare a b)
  |> List.map fst
  |> List.filteri (fun rank _ -> rank < num)
(** [predict x_test x_train y_train] classifies each sample of [x_test] by a
    majority [vote] over the labels of its [!k] nearest samples in [x_train],
    using [euclidean_distance]. The [i]-th element of [y_train] is taken as
    the label of the [i]-th element of [x_train]. Predictions are returned in
    the same order as [x_test].
    @raise Failure if [x_train] is empty or [!k <= 0] (nothing to vote on).
    @raise Invalid_argument if a training sample and a test sample have
    different lengths, or if a selected neighbour index has no label because
    [y_train] is shorter than [x_train]. *)
let predict x_test x_train y_train =
  (* Array gives O(1) label lookup. [List.nth] cost O(n) per neighbour,
     i.e. O(k * n) per test sample on top of the sort. *)
  let labels = Array.of_list y_train in
  (* Nothing below writes [k], so one read serves every test sample. *)
  let n_neighbours = !k in
  List.map
    (fun sample ->
      let distances =
        List.map (fun x -> euclidean_distance x sample) x_train
      in
      arg_sort distances n_neighbours
      |> List.map (fun i -> labels.(i))
      |> vote)
    x_test
(** [fit_and_predict x_train y_train x_test y_test] sets the global [k] to
    floor (sqrt (List.length x_train)), overwriting any value stored earlier
    with [set_k]. It then predicts labels for [x_test] and scores them
    against [y_test].
    Returns [(accuracy, mse, predictions)]. The two scores come from
    [Utils.accuracy] and [Utils.mean_squared_error], evaluated in that
    order. *)
let fit_and_predict x_train y_train x_test y_test =
  let n_train = float_of_int (List.length x_train) in
  k := int_of_float (Float.sqrt n_train);
  let predictions = predict x_test x_train y_train in
  let accuracy = Utils.accuracy y_test predictions in
  let mse = Utils.mean_squared_error y_test predictions in
  (accuracy, mse, predictions)