diff --git a/jsx/brisk_ppx.ml b/jsx/brisk_ppx.ml index e37de7e..c3a21c4 100644 --- a/jsx/brisk_ppx.ml +++ b/jsx/brisk_ppx.ml @@ -1,6 +1,7 @@ module P = Ppxlib.Ast module ATH = Ppxlib.Ast_helper module Ast_builder = Ppxlib.Ast_builder.Default +module Ast_pattern = Ppxlib.Ast_pattern let component_ident ~loc = Ast_builder.(pexp_ident ~loc (Located.lident ~loc "brisk-component")) @@ -35,7 +36,7 @@ module JSX_ppx = struct ATH.Exp.apply ~loc ~attrs (component_ident ~loc) args let is_jsx = - let open Ppxlib.Ast_pattern in + let open Ast_pattern in let jsx_attr = attribute ~name:(string "JSX") ~payload:__ in fun attr -> parse jsx_attr Ppxlib.Location.none @@ -77,7 +78,7 @@ end module Declaration_ppx = struct let func_pattern = - Ppxlib.Ast_pattern.( + Ast_pattern.( alt ( pexp_fun __ __ __ __ |> map ~f:(fun f lbl opt_arg pat expr -> @@ -86,7 +87,7 @@ module Declaration_ppx = struct |> map ~f:(fun f ident expr -> f (`Newtype (ident, expr))) )) let match_ pattern ?on_error loc ast_node ~with_ = - Ppxlib.Ast_pattern.parse pattern ?on_error loc ast_node with_ + Ast_pattern.parse pattern ?on_error loc ast_node with_ let attribute_name = function | `Component -> "component" @@ -123,19 +124,27 @@ module Declaration_ppx = struct | `Native -> [%expr Brisk_reconciler.Expert.nativeComponent] | `Component -> [%expr Brisk_reconciler.Expert.component] in - [%expr - let [%p component_ident_pattern ~loc] = - [%e create_component_expr] - ~useDynamicKey:[%e Ast_builder.(ebool ~loc useDynamicKey)] - [%e component_name] - in - fun ?(key = Brisk_reconciler.Key.none) -> - [%e map_component_expression expr]] + let fun_expr expr = + [%expr + let [%p component_ident_pattern ~loc] = + [%e create_component_expr] + ~useDynamicKey:[%e Ast_builder.(ebool ~loc useDynamicKey)] + [%e component_name] + in + fun ?(key = Brisk_reconciler.Key.none) -> + [%e map_component_expression expr]] + in + match_ + Ast_pattern.(pexp_constraint __ __) + loc expr + ~with_:(fun expr core_type -> + Ast_builder.(pexp_constraint ~loc (fun_expr expr) core_type)) + ~on_error:(fun () -> fun_expr expr) let declare_attribute ctx typ = let open Ppxlib.Attribute in declare (attribute_name typ) ctx - Ppxlib.Ast_pattern.( + Ast_pattern.( alt_option (single_expr_payload (pexp_ident (lident __'))) (pstr nil)) (function | Some { txt = "useDynamicKey" } -> true @@ -187,7 +196,7 @@ module Declaration_ppx = struct in let transform ~useDynamicKey attribute value_binding = let value_binding_loc = value_binding.P.pvb_loc in - Ppxlib.Ast_pattern.(parse (value_binding ~pat:(ppat_var __) ~expr:__)) + Ast_pattern.(parse (value_binding ~pat:(ppat_var __) ~expr:__)) value_binding_loc value_binding (fun var_pat expr -> let component_name = ATH.Exp.constant ~loc:expr.P.pexp_loc (ATH.Const.string var_pat) @@ -211,22 +220,35 @@ module Declaration_ppx = struct | None -> unmatched_value_binding ) let register attribute = - let open Ppxlib in - Extension.declare (attribute_name attribute) - Extension.Context.structure_item + Ppxlib.Extension.declare (attribute_name attribute) + Ppxlib.Extension.Context.structure_item Ast_pattern.( - pstr - ( pstr_value __ (value_binding ~pat:(ppat_var __) ~expr:__ ^:: nil) - ^:: nil )) + pstr (pstr_value __ (value_binding ~pat:__ ~expr:__ ^:: nil) ^:: nil)) (fun ~loc ~path recursive pat expr -> + let pat, var_name = + let var_name pat = + match_ Ast_pattern.(ppat_var __) loc pat ~with_:(fun name -> name) + in + let var_pat name = + ATH.Pat.var ~loc (Ast_builder.Located.mk ~loc name) + in + match_ + Ast_pattern.(ppat_constraint __ __) + loc pat + ~with_:(fun pat core_type -> + let name = var_name pat in + (Ast_builder.(ppat_constraint ~loc (var_pat name) core_type), name)) + ~on_error:(fun () -> + let name = var_name pat in + (var_pat name, name)) + in let component_name = - ATH.Exp.constant ~loc (ATH.Const.string (path ^ "." ^ pat)) + ATH.Exp.constant ~loc (ATH.Const.string (path ^ "." ^ var_name)) in let transformed_expression = transform_component_expr ~useDynamicKey:false ~attribute ~component_name expr in - let pat = ATH.Pat.var ~loc (Ast_builder.Default.Located.mk ~loc pat) in match recursive with | Recursive -> [%stri let rec [%p pat] = [%e transformed_expression]] | Nonrecursive -> [%stri let [%p pat] = [%e transformed_expression]]) diff --git a/test/Components.re b/test/Components.re index 03d3f7d..ff658d0 100644 --- a/test/Components.re +++ b/test/Components.re @@ -264,3 +264,15 @@ module LocallyAbstractType: { (empty, hooks); }; }; + +// Test to make sure type annotations are accepted +module Pexp_constraint = { + let%component make: (~key: Key.t=?, unit) => element(node) = + ((), hooks) => (empty, hooks); +}; + +// Test to make sure type annotations with universal quantifiers are accepted +module Ppat_constraint = { + let%component make: 'a. (~key: Key.t=?, unit) => element(node) = + ((), hooks) => (empty, hooks); +};