diff --git a/plugins/qcheck-stm/src/config.ml b/plugins/qcheck-stm/src/config.ml index ac6fd680..f3b235ca 100644 --- a/plugins/qcheck-stm/src/config.ml +++ b/plugins/qcheck-stm/src/config.ml @@ -6,6 +6,8 @@ type t = { context : Context.t; sut_core_type : Ppxlib.core_type; init_sut : Ppxlib.expression; + include_ : string option; + protect_call : string option; } let get_sut_type_name config = @@ -97,6 +99,14 @@ let init path init_sut sut_str = let context = List.fold_left add context sigs in let* sut_core_type = sut_core_type sut_str and* init_sut = init_sut_from_string init_sut in - ok (sigs, { context; sut_core_type; init_sut }) + ok + ( sigs, + { + context; + sut_core_type; + init_sut; + include_ = None; + protect_call = None; + } ) with Gospel.Warnings.Error (l, k) -> error (Ortac_core.Warnings.GospelError k, l) diff --git a/plugins/qcheck-stm/src/ortac_qcheck_stm.ml b/plugins/qcheck-stm/src/ortac_qcheck_stm.ml index 345d4fd0..8fc44854 100644 --- a/plugins/qcheck-stm/src/ortac_qcheck_stm.ml +++ b/plugins/qcheck-stm/src/ortac_qcheck_stm.ml @@ -4,14 +4,15 @@ module Ir_of_gospel = Ir_of_gospel module Reserr = Reserr module Stm_of_ir = Stm_of_ir -let main path init sut include_ output quiet () = +let main path init sut include_ protect_call output quiet () = let open Reserr in let fmt = Registration.get_out_formatter output in let pp = pp quiet Ppxlib_ast.Pprintast.structure fmt in pp (let* sigs, config = Config.init path init sut in let* ir = Ir_of_gospel.run sigs config in - Stm_of_ir.stm include_ config ir) + let config = { config with include_; protect_call } in + Stm_of_ir.stm config ir) open Cmdliner @@ -38,6 +39,15 @@ end = struct system under test." ~docv:"INIT") + let protect_call = + Arg.( + value + & opt (some string) None + & info [ "p"; "protect-call" ] ~docv:"PROTECT_CALL" + ~doc: + "Protect the call of the QCheck tests with PROTECT_CALL. \ + PROTECT_CALL should be the name of a function.") + let term = let open Registration in Term.( @@ -46,6 +56,7 @@ end = struct $ init $ sut $ include_ + $ protect_call $ output_file $ quiet $ setup_log) diff --git a/plugins/qcheck-stm/src/stm_of_ir.ml b/plugins/qcheck-stm/src/stm_of_ir.ml index 80b56456..8ef17c49 100644 --- a/plugins/qcheck-stm/src/stm_of_ir.ml +++ b/plugins/qcheck-stm/src/stm_of_ir.ml @@ -770,7 +770,7 @@ let ghost_functions config = in aux config [] -let stm include_ config ir = +let stm config ir = let open Reserr in let* config, ghost_functions = ghost_functions config ir.ghost_functions in let warn = [%stri [@@@ocaml.warning "-26-27"]] in @@ -783,7 +783,7 @@ let stm include_ config ir = |> Mod.ident |> Incl.mk |> pstr_include) - include_ + config.include_ |> Option.to_list in let sut = sut_type config in @@ -841,11 +841,20 @@ let stm include_ config ir = let call_tests = let loc = Location.none in let descr = estring (module_name ^ " STM tests") in - [%stri - let _ = + let expr = + [%expr QCheck_base_runner.run_tests_main (let count = 1000 in [ STMTests.agree_test ~count ~name:[%e descr] ])] + in + let expr = + match config.protect_call with + | None -> expr + | Some f -> + pexp_apply (qualify [ "Spec" ] f) + [ (Nolabel, efun [ (Nolabel, punit) ] expr) ] + in + pstr_value Nonrecursive [ value_binding ~pat:ppat_any ~expr ] in ok ([ open_mod module_name ]