Skip to content

Sort mutually-recursive constants by linkage name, not id #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: flambda2.0-stable
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions middle_end/flambda/lifting/sort_lifted_constants.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,43 @@

open! Simplify_import

module SCC_lifted_constants = Strongly_connected_components.Make (CIS)
(* Sort code IDs and symbols by linkage name because ids are unstable. In
particular, if we sort by integer id as usual, expect tests can fail
because the roundtrip to .flt can change ids, which then changes the order
of the bindings. *)
module CIS_by_name = struct
type t = Linkage_name.t * CIS.t

include Identifiable.Make (struct
type nonrec t = t

let compare (name1, _) (name2, _) = Linkage_name.compare name1 name2
let equal (name1, _) (name2, _) = Linkage_name.equal name1 name2
let print ppf (_name, cis) = CIS.print ppf cis
let output out (_name, cis) = CIS.output out cis
let hash (name, _) = Linkage_name.hash name
end)

let code_id code_id : t =
Symbol.linkage_name (Code_id.code_symbol code_id), Code_id code_id

let symbol symbol : t =
Symbol.linkage_name symbol, Symbol symbol

let of_ (cis : CIS.t) =
match cis with
| Code_id c -> code_id c
| Symbol s -> symbol s

let set_of_unordered_set set =
set |> CIS.Set.to_seq |> Seq.map of_ |> Set.of_seq
end

module SCC_lifted_constants = Strongly_connected_components.Make (CIS_by_name)

let build_dep_graph lifted_constants =
(* Format.eprintf "SORTING:\n%!"; *)
LCS.fold lifted_constants ~init:(CIS.Map.empty, CIS.Map.empty)
LCS.fold lifted_constants ~init:(CIS_by_name.Map.empty, CIS_by_name.Map.empty)
~f:(fun (dep_graph, code_id_or_symbol_to_const) lifted_constant ->
(* Format.eprintf "One constant: %a\n%!" LC.print lifted_constant; *)
ListLabels.fold_left (LC.definitions lifted_constant)
Expand Down Expand Up @@ -52,16 +84,20 @@ let build_dep_graph lifted_constants =
let deps =
CIS.Set.union (CIS.set_of_symbol_set free_syms)
(CIS.set_of_code_id_set free_code_ids)
|> CIS_by_name.set_of_unordered_set
in
let being_defined =
D.bound_symbols definition
|> Bound_symbols.everything_being_defined
|> CIS_by_name.set_of_unordered_set
in
CIS.Set.fold
CIS_by_name.Set.fold
(fun being_defined (dep_graph, code_id_or_symbol_to_const) ->
let dep_graph = CIS.Map.add being_defined deps dep_graph in
let dep_graph =
CIS_by_name.Map.add being_defined deps dep_graph
in
let code_id_or_symbol_to_const =
CIS.Map.add being_defined lifted_constant
CIS_by_name.Map.add being_defined lifted_constant
code_id_or_symbol_to_const
in
dep_graph, code_id_or_symbol_to_const)
Expand All @@ -78,7 +114,7 @@ let sort0 lifted_constants =
in
(*
Format.eprintf "SCC graph is:@ %a\n%!"
(CIS.Map.print CIS.Set.print)
(CIS_by_name.Map.print CIS_by_name.Set.print)
lifted_constants_dep_graph;
*)
let innermost_first =
Expand All @@ -92,21 +128,22 @@ let sort0 lifted_constants =
in
let _, lifted_constants =
ListLabels.fold_left code_id_or_symbols
~init:(CIS.Set.empty, [])
~init:(CIS_by_name.Set.empty, [])
~f:(fun ((already_seen, definitions) as acc) code_id_or_symbol ->
if CIS.Set.mem code_id_or_symbol already_seen then acc
if CIS_by_name.Set.mem code_id_or_symbol already_seen then acc
else
let lifted_constant =
CIS.Map.find code_id_or_symbol code_id_or_symbol_to_const
CIS_by_name.Map.find code_id_or_symbol code_id_or_symbol_to_const
in
let already_seen =
(* We may encounter the same defining expression more
than once, in the case of sets of closures, which
may bind more than one symbol. We must avoid
duplicates in the resulting [LC.t]. *)
let bound_symbols = LC.bound_symbols lifted_constant in
CIS.Set.union
(Bound_symbols.everything_being_defined bound_symbols)
CIS_by_name.Set.union
(Bound_symbols.everything_being_defined bound_symbols
|> CIS_by_name.set_of_unordered_set)
already_seen
in
already_seen, lifted_constant :: definitions)
Expand Down