diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fa3963..b80f086 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +unreleased +---------- + +* Add show @printer support for polymorphic variants + #286 + (Simmo Saan and Guillaume Huysmans) + 6.0.2 ----- diff --git a/src_plugins/show/ppx_deriving_show.ml b/src_plugins/show/ppx_deriving_show.ml index f1bf5b0..c8d0145 100644 --- a/src_plugins/show/ppx_deriving_show.ml +++ b/src_plugins/show/ppx_deriving_show.ml @@ -23,6 +23,7 @@ let attr_printer context = Attribute.declare "deriving.show.printer" context Ast_pattern.(single_expr_payload __) (fun e -> e) let ct_attr_printer = attr_printer Attribute.Context.core_type let constr_attr_printer = attr_printer Attribute.Context.constructor_declaration +let rtag_attr_printer = attr_printer Attribute.Context.rtag let ct_attr_polyprinter = Attribute.declare "deriving.show.polyprinter" Attribute.Context.core_type Ast_pattern.(single_expr_payload __) (fun e -> e) @@ -160,21 +161,29 @@ let rec expr_of_typ quoter typ = | { ptyp_desc = Ptyp_variant (fields, _, _); ptyp_loc } -> let cases = fields |> List.map (fun field -> - match field.prf_desc with - | Rtag(label, true (*empty*), []) -> + match Attribute.get rtag_attr_printer field, field.prf_desc with + | Some printer, Rtag(label, true (*empty*), []) -> + let label = label.txt in + Exp.case (Pat.variant label None) + [%expr [%e wrap_printer quoter printer] fmt ()] + | None, Rtag(label, true (*empty*), []) -> let label = label.txt in Exp.case (Pat.variant label None) [%expr Ppx_deriving_runtime.Format.pp_print_string fmt [%e str ("`" ^ label)]] - | Rtag(label, false, [typ]) -> + | Some printer, Rtag(label, false, [typ]) -> + let label = label.txt in + Exp.case (Pat.variant label (Some [%pat? x])) + [%expr [%e wrap_printer quoter printer] fmt x] + | None, Rtag(label, false, [typ]) -> let label = label.txt in Exp.case (Pat.variant label (Some [%pat? x])) [%expr Ppx_deriving_runtime.Format.fprintf fmt [%e str ("`" ^ label ^ " (@[")]; [%e expr_of_typ typ] x; Ppx_deriving_runtime.Format.fprintf fmt "@])"] - | Rinherit({ ptyp_desc = Ptyp_constr (tname, _) } as typ) -> + | _, Rinherit({ ptyp_desc = Ptyp_constr (tname, _) } as typ) -> Exp.case [%pat? [%p Pat.type_ tname] as x] [%expr [%e expr_of_typ typ] x] - | _ -> + | _, _ -> raise_errorf ~loc:ptyp_loc "%s cannot be derived for %s" deriver (Ppx_deriving.string_of_core_type typ)) in diff --git a/src_test/show/test_deriving_show.cppo.ml b/src_test/show/test_deriving_show.cppo.ml index 7a2d755..6ef97d8 100644 --- a/src_test/show/test_deriving_show.cppo.ml +++ b/src_test/show/test_deriving_show.cppo.ml @@ -233,6 +233,25 @@ let test_variant_printer ctxt = assert_equal ~printer "fourth: 8 4" (show_variant_printer (Fourth(8,4))) +type poly_variant_printer = [ + | `First [@printer fun fmt _ -> Format.pp_print_string fmt "first"] + | `Second of int [@printer fun fmt i -> fprintf fmt "second: %d" i] + | `Third + | `Fourth of int * int + [@printer fun fmt (a,b) -> fprintf fmt "fourth: %d %d" a b] +] +[@@deriving show] + +let test_poly_variant_printer ctxt = + assert_equal ~printer + "first" (show_poly_variant_printer `First); + assert_equal ~printer + "second: 42" (show_poly_variant_printer (`Second 42)); + assert_equal ~printer + "`Third" (show_poly_variant_printer `Third); + assert_equal ~printer + "fourth: 8 4" (show_poly_variant_printer (`Fourth(8,4))) + type no_full = NoFull of int [@@deriving show { with_path = false }] type with_full = WithFull of int [@@deriving show { with_path = true }] module WithFull = struct @@ -264,6 +283,7 @@ let suite = "Test deriving(show)" >::: [ "test_std_shadowing" >:: test_std_shadowing; "test_poly_app" >:: test_poly_app; "test_variant_printer" >:: test_variant_printer; + "test_poly_variant_printer" >:: test_poly_variant_printer; "test_paths" >:: test_paths_printer; "test_result" >:: test_result; "test_result_result" >:: test_result_result;