Skip to content

Commit

Permalink
Refactoring and optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
SGrondin committed Jan 7, 2023
1 parent e451f42 commit 5ad1a4e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 76 deletions.
11 changes: 7 additions & 4 deletions bin/bench.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Size: %d
-----|s}
x.count x.size x.cumulates_count x.compress_count x.auto_compress_count

let snapshot sf t0 t1 =
print_endline (sprintf sf Int63.((t1 - t0) / of_int 1_000 |> to_string_hum ~delimiter:','))

let () =
let arr = Array.init 2_000_000 ~f:(fun _i -> Random.float 1.) in
let td = Tdigest.create () in
Expand All @@ -25,20 +28,20 @@ let () =
let t2 = Time_now.nanoseconds_since_unix_epoch () in
let info2 = Tdigest.info td in

print_endline (sprintf "Add 100k: %sms" Int63.((t1 - t0) / of_int 1_000_000 |> to_string));
snapshot "Add 100k: %s us" t0 t1;
print_endline (info_to_string info1);
print_endline (sprintf "Compress: %sms" Int63.((t2 - t1) / of_int 1_000_000 |> to_string));
snapshot "Compress: %s us" t1 t2;
print_endline (info_to_string info2);

let t3 = Time_now.nanoseconds_since_unix_epoch () in
let _td, str = Tdigest.to_string td in
let t4 = Time_now.nanoseconds_since_unix_epoch () in
print_endline (sprintf "Serialized into %d bytes" (String.length str));
print_endline (sprintf "Serialization: %sms" Int63.((t4 - t3) / of_int 1_000_000 |> to_string));
snapshot "Serialization: %s us" t3 t4;

let t5 = Time_now.nanoseconds_since_unix_epoch () in
let td = Tdigest.of_string (sprintf "%s%s" str str) in
let t6 = Time_now.nanoseconds_since_unix_epoch () in
print_endline (sprintf "Parsing: %sms" Int63.((t6 - t5) / of_int 1_000_000 |> to_string));
snapshot "Parsing: %s us" t5 t6;
let info3 = Tdigest.info td in
print_endline (info_to_string info3)
123 changes: 55 additions & 68 deletions src/tdigest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type settings = {
delta: delta;
k: k;
cx: cx;
k_delta: float option;
}
[@@deriving sexp]

Expand Down Expand Up @@ -69,13 +70,6 @@ type info = {
}
[@@deriving sexp]

type bounds =
| Neither
| Both of centroid * centroid
| Equal of centroid
| Lower of centroid
| Upper of centroid

let get_min = function
| { min = Some _ as x; _ } -> x
| { min = None; n = 0.0; _ } -> None
Expand All @@ -92,6 +86,10 @@ let get_max = function
td.max <- max;
max

let get_k_delta = function
| Automatic k, Merging delta -> Some (k / delta)
| _ -> None

let create ?(delta = default_delta) ?(k = default_k) ?(cx = default_cx) () =
let k =
match k with
Expand All @@ -113,7 +111,7 @@ let create ?(delta = default_delta) ?(k = default_k) ?(cx = default_cx) () =
| Growth x -> invalid_argf "TDigest cx parameter must be positive, but was %f" x ()
in
{
settings = { delta; k; cx };
settings = { delta; k; cx; k_delta = get_k_delta (k, delta) };
centroids = Map.empty;
min = None;
max = None;
Expand Down Expand Up @@ -222,39 +220,29 @@ let internal_digest td ~n ~mean =
in
cumulate td ~exact:false

let shuffle_inplace arr =
let _i =
Array.fold_right arr ~init:(Array.length arr) ~f:(fun _x i ->
let random = Random.float (of_int i) |> to_int in
let current = pred i in
Array.swap arr random current;
current)
in
()

let weights_of_td = function
(* n is out of sync, must check centroids *)
| { centroids; _ } when Map.is_empty centroids -> [||]
| { centroids; _ } ->
let arr = Array.create ~len:(Map.length centroids) empty_centroid in
let _i =
Map.fold centroids ~init:0 ~f:(fun ~key:_ ~data i ->
Array.set arr i data;
arr.(i) <- data;
succ i)
in
arr

let weights_of_map ~len map =
let arr = Array.create ~len empty_centroid in
let weights_of_table table =
let arr = Array.create ~len:(Table.length table) empty_centroid in
let _i =
Map.fold map ~init:0 ~f:(fun ~key:mean ~data:n i ->
Array.set arr i { empty_centroid with mean; n };
Table.fold table ~init:0 ~f:(fun ~key:mean ~data:n i ->
arr.(i) <- { empty_centroid with mean; n };
succ i)
in
arr

let rebuild ~auto settings (stats : stats) arr =
shuffle_inplace arr;
Array.permute arr;
let blank =
{
settings;
Expand All @@ -275,12 +263,16 @@ let rebuild ~auto settings (stats : stats) arr =
cumulate td ~exact:true

let digest ?(n = 1) td ~mean =
let td = internal_digest td ~n:(Int.to_float n) ~mean in
let td = internal_digest td ~n:(of_int n) ~mean in
match td.settings with
| { delta = Merging delta; k = Automatic k; _ } when Map.length td.centroids |> of_int > k / delta ->
| { k_delta = Some kd; _ } when Map.length td.centroids |> of_int > kd ->
rebuild ~auto:true td.settings td.stats (weights_of_td td)
| _ -> td

let add ?(n = 1) ~data td = digest td ~n ~mean:data

let add_list ?(n = 1) xs td = List.fold xs ~init:td ~f:(fun acc mean -> digest acc ~n ~mean)

let compress ?delta td =
match delta with
| None -> rebuild ~auto:false td.settings td.stats (weights_of_td td)
Expand All @@ -289,10 +281,6 @@ let compress ?delta td =
let updated = rebuild ~auto:false { td.settings with delta } td.stats (weights_of_td td) in
{ updated with settings }

let add ?(n = 1) ~data td = digest td ~n ~mean:data

let add_list ?(n = 1) xs td = List.fold xs ~init:td ~f:(fun acc mean -> digest acc ~n ~mean)

let to_string td =
let buf = Bytes.create (Map.length td.centroids |> Int.( * ) 16) in
let add_float pos ~data:f =
Expand All @@ -311,49 +299,48 @@ let to_string td =
in
td, Bytes.unsafe_to_string ~no_mutation_while_string_reachable:buf

let parse_float str pos =
let open Int64 in
let next off = String.get str Int.(pos + off) |> Char.to_int |> of_int_exn in
(String.get str pos |> Char.to_int |> of_int_exn)
lor shift_left (next 1) 8
lor shift_left (next 2) 16
lor shift_left (next 3) 24
lor shift_left (next 4) 32
lor shift_left (next 5) 40
lor shift_left (next 6) 48
lor shift_left (next 7) 56
|> float_of_bits

let of_string ?(delta = default_delta) ?(k = default_k) ?(cx = default_cx) str =
if Int.(String.length str % 16 <> 0) then invalid_arg "Invalid string length for Tdigest.of_string";
let settings = { delta; k; cx } in
let _i, _mean, _n, map =
String.fold str ~init:(0, 0L, 0L, Map.empty) ~f:(fun (i, pmean, pn, acc) c ->
let x = c |> Char.to_int |> Int64.of_int_exn in
match i with
| 0
|1
|2
|3
|4
|5
|6
|7 ->
let mean = Int64.(pmean lor shift_left x Int.(i * 8)) in
succ i, mean, pn, acc
| 8
|9
|10
|11
|12
|13
|14 ->
let n = Int64.(pn lor shift_left x Int.((i - 8) * 8)) in
succ i, pmean, n, acc
| 15 ->
let mean = Int64.float_of_bits pmean in
let n = Int64.(pn lor shift_left x 56 |> float_of_bits) in
let acc = Map.update acc mean ~f:(Option.value_map ~default:n ~f:(( + ) n)) in
0, 0L, 0L, acc
| x -> failwithf "Tdigest.of_string: impossible case '%d'. Please report this bug." x ())
let settings = { delta; k; cx; k_delta = get_k_delta (k, delta) } in
let table = Table.create () in
let rec loop = function
| pos when Int.(pos = String.length str) -> ()
| pos ->
let mean = parse_float str pos in
let n = parse_float str Int.(pos + 8) in
Table.update table mean ~f:(Option.value_map ~default:n ~f:(( + ) n));
(loop [@tailcall]) Int.(pos + 16)
in
weights_of_map ~len:Int.(String.length str / 16) map |> rebuild ~auto:true settings empty_stats
loop 0;
weights_of_table table |> rebuild ~auto:true settings empty_stats

let merge ?(delta = default_delta) ?(k = default_k) ?(cx = default_cx) tds =
let settings = { delta; k; cx } in
let map =
List.fold tds ~init:Map.empty ~f:(fun acc { centroids; _ } ->
Map.fold centroids ~init:acc ~f:(fun ~key:_ ~data:{ mean; n; _ } acc ->
Map.update acc mean ~f:(Option.value_map ~default:n ~f:(( + ) n))))
in
weights_of_map ~len:(Map.length map) map |> rebuild ~auto:true settings empty_stats
let settings = { delta; k; cx; k_delta = get_k_delta (k, delta) } in
let table = Table.create () in
List.iter tds ~f:(fun { centroids; _ } ->
Map.iter centroids ~f:(fun { mean; n; _ } ->
Table.update table mean ~f:(Option.value_map ~default:n ~f:(( + ) n))));
weights_of_table table |> rebuild ~auto:true settings empty_stats

type bounds =
| Neither
| Both of centroid * centroid
| Equal of centroid
| Lower of centroid
| Upper of centroid

let bounds td needle lens =
let gt = ref None in
Expand Down
8 changes: 4 additions & 4 deletions test/test_tdigest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ let%expect_test "compress" =
check_min_max td;
[%expect {|
100
46
45
(((0 1)) ((990 1))) |}]
in

Expand All @@ -75,7 +75,7 @@ let%expect_test "compress" =
check_size td;
check_min_max td;
[%expect {|
2128
2156
(((0 1)) ((99990 1))) |}]
in
()
Expand Down Expand Up @@ -151,7 +151,7 @@ let%expect_test "percentile ranks" =
in
Float.to_string max_err |> print_endline;
(* must be < 0.01 *)
[%expect {| 0.0020868899459022261 |}]
[%expect {| 0.0020962038264262794 |}]
in

(* from an exact match *)
Expand Down Expand Up @@ -207,7 +207,7 @@ let%expect_test "percentiles" =
in
Float.to_string max_err |> print_endline;
(* must be < 0.01 *)
[%expect {| 0.0020868899459022261 |}]
[%expect {| 0.0020962038264262794 |}]
in
()

Expand Down

0 comments on commit 5ad1a4e

Please sign in to comment.