Skip to content

Commit

Permalink
Get rid of GlobalContextMessage.tactics
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Nov 8, 2024
1 parent 2c74b6a commit 301cd02
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 69 deletions.
22 changes: 14 additions & 8 deletions pytact/fake_python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,25 @@ async def text_prediction_loop(context : GlobalContextMessage):
else:
raise Exception("Capnp protocol error")

async def graph_prediction_loop(context : GlobalContextMessage, level):
async def graph_prediction_loop(context : GlobalContextMessage, prev_tactics, level):
print(f"level {level}")
for cluster in context.definitions.clustered_definitions(full = False):
print('cluster:')
for d in cluster:
print(f' {d.name}')
for t in context.tactics:
print(t)

tactics = prev_tactics.copy()
for d in context.definitions.definitions(full = False):
if p := d.proof:
for ps in p:
if t := ps.tactic:
tactics.add((t.ident, len(ps.outcomes[0].tactic_arguments)))

print(context.log_annotation)
prediction_requests = context.prediction_requests
cool_definitions = [ d.node for d in context.definitions.definitions() if d.name == "Coq.Init.Logic.I" ]
zeroArgs = [t.ident for t in context.tactics if t.parameters == 0]
oneArg = [t.ident for t in context.tactics if t.parameters == 1]
zeroArgs = [ident for (ident, parameters) in tactics if parameters == 0]
oneArg = [ident for (ident, parameters) in tactics if parameters == 1]
async for msg in prediction_requests:
# Redirect any exceptions to Coq. Additionally, deal with CancellationError
# thrown when a request from Coq is cancelled
Expand All @@ -61,11 +67,11 @@ async def graph_prediction_loop(context : GlobalContextMessage, level):
await prediction_requests.asend(TacticPredictionsGraph(preds))
elif isinstance(msg, CheckAlignmentMessage):
unknown_definitions = list(context.definitions.definitions())
unknown_tactics = [t.ident for t in context.tactics]
unknown_tactics = [ident for (ident, _) in tactics]
alignment = CheckAlignmentResponse(unknown_definitions, unknown_tactics)
await prediction_requests.asend(alignment)
elif isinstance(msg, GlobalContextMessage):
await graph_prediction_loop(msg, level + 1)
await graph_prediction_loop(msg, tactics, level + 1)
else:
raise Exception(f"Capnp protocol error {msg}")

Expand All @@ -80,7 +86,7 @@ async def run_session(args, record_file, capnp_stream):
await text_prediction_loop(messages_generator)
elif args.mode == 'graph':
print('Python server running in graph mode')
await graph_prediction_loop(messages_generator, 0)
await graph_prediction_loop(messages_generator, set(), 0)
else:
raise Exception("The 'mode' argument needs to be either 'text' or 'graph'")

Expand Down
115 changes: 54 additions & 61 deletions src/neural_learner.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ let find_tactic tacs id =
| None -> raise NoSuchTactic
| Some x -> x

let add_tactic_info env map tac =
let { base_tactic; args; _ } = analyze_tactic tac in
let params = List.length args in
TacticMap.add
(Tactic_hash.tactic_hash env base_tactic) (base_tactic, params) map

let find_local_argument context_map =
let context_map_inv =
Names.Id.Map.fold (fun id (_, node) m -> Int.Map.add node (Tactic_one_variable.TVar id) m)
Expand Down Expand Up @@ -193,18 +199,19 @@ type context_state =
; id : int
; constants : Environ.constant_key Cmap_env.t
; inductives : Environ.mind_key Mindmap_env.t
; section : Constr.named_context }
; section : Constr.named_context
; tactics : (glob_tactic_expr * int) TacticMap.t }
type context_stack =
{ stack : context_state list
; stack_size : int }

let update_context_stack id tacs env_extra env { stack_size; stack } =
let state, old_constants, old_inducives, old_section = match stack with
let update_context_stack ?(force=false) id env_extra env { stack_size; stack } =
let state, old_constants, old_inducives, old_section, tactics = match stack with
| [] ->
let (empty_state, ()), _ = CICGraphMonad.run_empty (CICGraphMonad.return ())
(G.HashMap.create 0) G.builder_nil 0 in
empty_state, Cmap_env.empty, Mindmap_env.empty, []
| { state; constants; inductives; section; _ }::_ -> state, constants, inductives, section in
empty_state, Cmap_env.empty, Mindmap_env.empty, [], TacticMap.empty
| { state; constants; inductives; section; tactics; _ }::_ -> state, constants, inductives, section, tactics in

let globals = Environ.Globals.view Environ.(env.env_globals) in
let section = Environ.named_context env in
Expand Down Expand Up @@ -236,8 +243,23 @@ let update_context_stack id tacs env_extra env { stack_size; stack } =
++ pr_vertical_list Id.print (Id.Set.elements new_section)
);

if Cset.is_empty new_constants && Mindset.is_empty new_inductives && Id.Set.is_empty new_section &&
TacticMap.is_empty tacs then state, { stack_size; stack } else
if (not force) && Cset.is_empty new_constants && Mindset.is_empty new_inductives && Id.Set.is_empty new_section
then state, tactics, { stack_size; stack } else

let update_tactics fold find tmap set map =
fold (fun c tmap ->
match find c map with
| None -> tmap
| Some ls ->
List.fold_left (fun tmap (_, t) ->
match t with
| None -> tmap
| Some t -> add_tactic_info env tmap t
) tmap ls
) set tmap in
let env_vars, env_const = env_extra in
let tactics = update_tactics Cset.fold Cmap.find_opt tactics new_constants env_const in
let tactics = update_tactics Id.Set.fold Id.Map.find_opt tactics new_section env_vars in

let { def_count; node_count; edge_count; defs; nodes; edges }, state =
let open Monad_util.WithMonadNotations(CICGraphMonad) in
Expand All @@ -264,12 +286,6 @@ let update_context_stack id tacs env_extra env { stack_size; stack } =
GlobalContextAddition.log_annotation_set init @@ log_annotation ();
ignore(GlobalContextAddition.data_version_set_reader init Api.Reader.current_version);
GlobalContextAddition.stack_size_set_int_exn init stack_size;
let tac_arr = GlobalContextAddition.tactics_init init @@ TacticMap.cardinal tacs in
List.iteri (fun i (hash, (_tac, params)) ->
let arri = Capnp.Array.get tac_arr i in
Api.Builder.AbstractTactic.ident_set arri hash;
Api.Builder.AbstractTactic.parameters_set_exn arri params)
(TacticMap.bindings tacs);
W.write_graph
~node_hash ~node_label ~node_lower:(fun n -> fst @@ G.lower n)
~node_dep_index:(fun (stack_id, _) -> stack_size - stack_id) ~node_local_index
Expand All @@ -284,13 +300,14 @@ let update_context_stack id tacs env_extra env { stack_size; stack } =
let state = { state with
previous = None
; external_previous = Option.cata (fun p -> [p]) state.external_previous state.previous } in
state, { stack_size = stack_size + 1
; stack = { request = builder
; state; id
; constants = globals.constants
; inductives = globals.inductives
; section }
::stack }
state, tactics, { stack_size = stack_size + 1
; stack = { request = builder
; state; id
; constants = globals.constants
; inductives = globals.inductives
; section
; tactics }
::stack }

let context_stack = Summary.ref ~name:"neural-learner-graph-cache"
{ stack = []; stack_size = 0 }
Expand All @@ -300,13 +317,13 @@ let sync_context_stack add_global_context =
let id = ref 0 in
let remote_state = ref [] in
let remote_stack_size = ref 0 in
fun ?(keep_cache=true) tacs env_extra env ->
fun ?(keep_cache=true) ?(force=false) env_extra env ->
if debug_option () then
Feedback.msg_notice Pp.(
str "old remote stack : " ++ prlist_with_sep (fun () -> str "-") int !remote_state ++ fnl () ++
str "old local stack : " ++ prlist_with_sep (fun () -> str "-")
(fun { id; _ } -> int id) !context_stack.stack);
let state, ({ stack_size; stack } as cache) = update_context_stack !id tacs env_extra env !context_stack in
let state, tactics, ({ stack_size; stack } as cache) = update_context_stack ~force !id env_extra env !context_stack in
if keep_cache then
context_stack := cache;
if debug_option () then
Expand All @@ -333,7 +350,7 @@ let sync_context_stack add_global_context =
remote_stack_size := stack_size;
if debug_option () then
Feedback.msg_notice Pp.(str "new remote stack : " ++ prlist_with_sep (fun () -> str "-") int !remote_state);
state, stack_size
state, tactics, stack_size

type capnp_connection =
{ rc : Unix.file_descr Capnp_unix.IO.ReadContext.t
Expand Down Expand Up @@ -383,8 +400,8 @@ type connection =

type communicator =
{ add_global_context : (Api.Builder.GlobalContextAddition.t -> unit) -> unit
; sync_context_stack : ?keep_cache:bool -> (glob_tactic_expr * location) TacticMap.t -> env_extra -> Environ.env ->
CICGraphMonad.state * int
; sync_context_stack : ?keep_cache:bool -> ?force:bool -> env_extra -> Environ.env ->
CICGraphMonad.state * (glob_tactic_expr * int) TacticMap.t * int
; request_prediction : (Api.Builder.PredictionRequest.t -> unit) ->
(Graph_api.ro, Api.Reader.Prediction.t, Api.Reader.array_t) Capnp.Array.t
; request_text_prediction : (Api.Builder.PredictionRequest.t -> unit) ->
Expand Down Expand Up @@ -675,29 +692,14 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
preds

type model =
{ tactics : (glob_tactic_expr * int) TacticMap.t
; proofs : env_extra }
{ proofs : env_extra }

let last_model = Summary.ref ~name:"neural-learner-lastmodel" { tactics = TacticMap.empty
; proofs = Id.Map.empty, Cmap.empty }
let last_model = Summary.ref ~name:"neural-learner-lastmodel" { proofs = Id.Map.empty, Cmap.empty }

let empty () =
{ tactics = TacticMap.empty; proofs = Id.Map.empty, Cmap.empty }

let add_tactic_info env map tac =
let { base_tactic; args; _ } = analyze_tactic tac in
let params = List.length args in
if params >= 256 then map else
TacticMap.add
(Tactic_hash.tactic_hash env base_tactic) (base_tactic, params) map

let learn { tactics; proofs = var_proofs, const_proofs } (kn, path, status) outcomes tac =
let tactics = match tac with
| None -> tactics
| Some tac ->
let tac = tactic_repr tac in
let tactics = add_tactic_info (Global.env ()) tactics tac in
tactics in
{ proofs = Id.Map.empty, Cmap.empty }

let learn { proofs = var_proofs, const_proofs } (kn, path, status) outcomes tac =

(* TODO: Filtering out bad proof states:
Occasionally, proof states refer to section variables that have been filtered out by Coq during section
Expand Down Expand Up @@ -733,14 +735,6 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
(TS.proof_state_hypotheses ps) ~init:status in
status in

(* TODO: Ridiculous tactic filtering: *)
let tac = match tac with
| None -> None
| Some tac ->
let { base_tactic; args; _ } = analyze_tactic @@ tactic_repr tac in
let params = List.length args in
if params >= 256 then None else Some tac in

(* TODO: Drop-in shadowing replacement for mk_outcome. For now, we don't need the proof term and after
states. We butcher them to make the payload smaller and faster to compute. *)
let mk_outcome before result =
Expand All @@ -758,21 +752,21 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
(* TODO: This is not entirely correct:
We always attach a proof to a constant, never to a section variable. No good solution for now. *)
let const_proofs = Cmap.update constant update const_proofs in
let db = { tactics; proofs = var_proofs, const_proofs } in
let db = { proofs = var_proofs, const_proofs } in
last_model := db;
db

let predict { tactics; proofs } =
let predict { proofs } =
let { add_global_context; sync_context_stack
; request_prediction; request_text_prediction; _ } = get_communicator () in
let env = Global.env () in
if not @@ textmode_option () then
let state, stack_size =
sync_context_stack ~keep_cache:false tactics proofs env in
let state, tacs, stack_size =
sync_context_stack ~keep_cache:false ~force:true proofs env in
let find_global_argument = find_global_argument state in
fun f ->
if f = [] then IStream.empty else
let preds = predict request_prediction find_global_argument stack_size state tactics env
let preds = predict request_prediction find_global_argument stack_size state tacs env
(List.hd f).state in
let preds = List.map (fun (t, c) -> { confidence = c; focus = 0; tactic = tactic_make t }) preds in
IStream.of_list preds
Expand All @@ -792,9 +786,8 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
let module Request = Api.Builder.PredictionProtocol.Request in
let module Response = Api.Reader.PredictionProtocol.Response in
let env = Global.env () in
let { tactics; proofs } = !last_model in
let proofs = (!last_model).proofs in
let state, stack_size = sync_context_stack ~keep_cache:false tactics proofs env in
let state, tactics, stack_size = sync_context_stack ~keep_cache:false proofs env in
let request = Request.init_root () in
Request.check_alignment_set request;
let unaligned_tacs, unaligned_defs = check_alignment () in
Expand Down Expand Up @@ -836,7 +829,7 @@ module NeuralLearner : TacticianOnlineLearnerType = functor (TS : TacticianStruc
(* We don't send the list of tactics, hence the empty list. Tactics are only sent right before
prediction requests are made. *)
let proofs = (!last_model).proofs in
let _, stack_size = sync_context_stack TacticMap.empty proofs (Global.env ()) in
let _, _, stack_size = sync_context_stack proofs (Global.env ()) in
if debug_option () then
Feedback.msg_notice Pp.(str "Cache stack size: " ++ int stack_size)

Expand Down

0 comments on commit 301cd02

Please sign in to comment.