diff --git a/middle_end/flambda/lifting/sort_lifted_constants.ml b/middle_end/flambda/lifting/sort_lifted_constants.ml index f79aa03a44fe..b42c9a0b7190 100644 --- a/middle_end/flambda/lifting/sort_lifted_constants.ml +++ b/middle_end/flambda/lifting/sort_lifted_constants.ml @@ -18,11 +18,43 @@ open! Simplify_import -module SCC_lifted_constants = Strongly_connected_components.Make (CIS) +(* Sort code IDs and symbols by linkage name because ids are unstable. In + particular, if we sort by integer id as usual, expect tests can fail + because the roundtrip to .flt can change ids, which then changes the order + of the bindings. *) +module CIS_by_name = struct + type t = Linkage_name.t * CIS.t + + include Identifiable.Make (struct + type nonrec t = t + + let compare (name1, _) (name2, _) = Linkage_name.compare name1 name2 + let equal (name1, _) (name2, _) = Linkage_name.equal name1 name2 + let print ppf (_name, cis) = CIS.print ppf cis + let output out (_name, cis) = CIS.output out cis + let hash (name, _) = Linkage_name.hash name + end) + + let code_id code_id : t = + Symbol.linkage_name (Code_id.code_symbol code_id), Code_id code_id + + let symbol symbol : t = + Symbol.linkage_name symbol, Symbol symbol + + let of_ (cis : CIS.t) = + match cis with + | Code_id c -> code_id c + | Symbol s -> symbol s + + let set_of_unordered_set set = + set |> CIS.Set.to_seq |> Seq.map of_ |> Set.of_seq +end + +module SCC_lifted_constants = Strongly_connected_components.Make (CIS_by_name) let build_dep_graph lifted_constants = (* Format.eprintf "SORTING:\n%!"; *) - LCS.fold lifted_constants ~init:(CIS.Map.empty, CIS.Map.empty) + LCS.fold lifted_constants ~init:(CIS_by_name.Map.empty, CIS_by_name.Map.empty) ~f:(fun (dep_graph, code_id_or_symbol_to_const) lifted_constant -> (* Format.eprintf "One constant: %a\n%!" LC.print lifted_constant; *) ListLabels.fold_left (LC.definitions lifted_constant) @@ -52,16 +84,20 @@ let build_dep_graph lifted_constants = let deps = CIS.Set.union (CIS.set_of_symbol_set free_syms) (CIS.set_of_code_id_set free_code_ids) + |> CIS_by_name.set_of_unordered_set in let being_defined = D.bound_symbols definition |> Bound_symbols.everything_being_defined + |> CIS_by_name.set_of_unordered_set in - CIS.Set.fold + CIS_by_name.Set.fold (fun being_defined (dep_graph, code_id_or_symbol_to_const) -> - let dep_graph = CIS.Map.add being_defined deps dep_graph in + let dep_graph = + CIS_by_name.Map.add being_defined deps dep_graph + in let code_id_or_symbol_to_const = - CIS.Map.add being_defined lifted_constant + CIS_by_name.Map.add being_defined lifted_constant code_id_or_symbol_to_const in dep_graph, code_id_or_symbol_to_const) @@ -78,7 +114,7 @@ let sort0 lifted_constants = in (* Format.eprintf "SCC graph is:@ %a\n%!" - (CIS.Map.print CIS.Set.print) + (CIS_by_name.Map.print CIS_by_name.Set.print) lifted_constants_dep_graph; *) let innermost_first = @@ -92,12 +128,12 @@ let sort0 lifted_constants = in let _, lifted_constants = ListLabels.fold_left code_id_or_symbols - ~init:(CIS.Set.empty, []) + ~init:(CIS_by_name.Set.empty, []) ~f:(fun ((already_seen, definitions) as acc) code_id_or_symbol -> - if CIS.Set.mem code_id_or_symbol already_seen then acc + if CIS_by_name.Set.mem code_id_or_symbol already_seen then acc else let lifted_constant = - CIS.Map.find code_id_or_symbol code_id_or_symbol_to_const + CIS_by_name.Map.find code_id_or_symbol code_id_or_symbol_to_const in let already_seen = (* We may encounter the same defining expression more @@ -105,8 +141,9 @@ let sort0 lifted_constants = may bind more than one symbol. We must avoid duplicates in the resulting [LC.t]. *) let bound_symbols = LC.bound_symbols lifted_constant in - CIS.Set.union - (Bound_symbols.everything_being_defined bound_symbols) + CIS_by_name.Set.union + (Bound_symbols.everything_being_defined bound_symbols + |> CIS_by_name.set_of_unordered_set) already_seen in already_seen, lifted_constant :: definitions)