ligo/expander/ppx_let_expander.ml
2019-01-16 10:57:51 +00:00

177 lines
5.1 KiB
OCaml

open Base
open Ppxlib
open Ast_builder.Default
module List = struct
  include List

  (* [reduce_exn items ~f] folds [f] over [items] from left to right, using
     the first element as the initial accumulator.
     Raises [Invalid_argument] when [items] is empty. *)
  let reduce_exn items ~f =
    match items with
    | first :: rest -> fold_left rest ~init:first ~f
    | [] -> invalid_arg "List.reduce_exn"
  ;;
end
module Extension_name = struct
  (* The extension points this expander understands. *)
  type t =
    | Bind
    | Bind_open
    | Map
    | Map_open

  (* Name of the [Let_syntax] function the extension expands to. The [_open]
     variants use the same operator as their plain counterpart. *)
  let operator_name t =
    match t with
    | Map | Map_open -> "map"
    | Bind | Bind_open -> "bind"
  ;;

  (* Name of the extension itself, as interpolated into error messages. *)
  let to_string t =
    match t with
    | Map_open -> "map_open"
    | Map -> "map"
    | Bind_open -> "bind_open"
    | Bind -> "bind"
  ;;
end
(* Path of the [Let_syntax] module: the bare identifier when [modul] is
   [None], otherwise [Let_syntax] qualified by the given module path. *)
let let_syntax ~modul : Longident.t =
  let name = "Let_syntax" in
  match modul with
  | Some { txt = prefix; _ } -> Ldot (prefix, name)
  | None -> Lident name
;;
(* Located identifier for the [Open_on_rhs] submodule of the [Let_syntax]
   module selected by [modul]. *)
let open_on_rhs ~loc ~modul =
  let path : Longident.t = Ldot (let_syntax ~modul, "Open_on_rhs") in
  Located.mk ~loc path
;;
(* Identifier expression for the value named [func] inside the [Let_syntax]
   module selected by [modul]. *)
let eoperator ~loc ~modul func =
  Longident.Ldot (let_syntax ~modul, func) |> Located.mk ~loc |> pexp_ident ~loc
;;
(* Expands a group of bindings through [f].

   A single binding is handed to [f] unchanged. Otherwise one fresh variable
   is generated per binding and the result has the shape

     let v1 = e1 and v2 = e2 ... in <f [p1 = v1; p2 = v2; ...] expr>

   i.e. the original right-hand sides are bound to the fresh variables by an
   ordinary non-recursive [let], and [f] only sees bindings whose right-hand
   sides are those variables. *)
let expand_with_tmp_vars ~loc bindings expr ~f =
  match bindings with
  | [ _ ] -> f ~loc bindings expr
  | _ ->
    let fresh_names =
      List.map bindings ~f:(fun _ -> gen_symbol ~prefix:"__let_syntax" ())
    in
    (* Replace the right-hand side of a binding by its fresh variable. *)
    let with_var_as_rhs vb name =
      { vb with pvb_expr = evar ~loc:vb.pvb_expr.pexp_loc name }
    in
    (* Replace the pattern of a binding by its fresh variable. *)
    let with_var_as_lhs vb name =
      { vb with pvb_pat = pvar ~loc:vb.pvb_pat.ppat_loc name }
    in
    let inner_bindings = List.map2_exn bindings fresh_names ~f:with_var_as_rhs in
    let outer_bindings = List.map2_exn bindings fresh_names ~f:with_var_as_lhs in
    let inner = f ~loc inner_bindings expr in
    pexp_let ~loc Nonrecursive outer_bindings inner
;;
(* Builds the application [Let_syntax.<op> arg ~f:fn], where [<op>] is the
   operator associated with [extension_name]. *)
let bind_apply ~loc ~modul extension_name ~arg ~fn =
  let operator =
    eoperator ~loc ~modul (Extension_name.operator_name extension_name)
  in
  pexp_apply ~loc operator [ Nolabel, arg; Labelled "f", fn ]
;;
(* For the [_open] extensions, wraps [expr] in a local open of the module
   produced by [to_open] (called with the location of [expr]); for the plain
   extensions, returns [expr] untouched. *)
let maybe_open extension_name ~to_open:module_to_open expr =
  match (extension_name : Extension_name.t) with
  | Bind_open | Map_open ->
    let loc = expr.pexp_loc in
    pexp_open ~loc Override (module_to_open ~loc) expr
  | Bind | Map -> expr
;;
(* Expands a non-empty group of bindings into a single operator application:
   the right-hand sides are combined with nested [both] calls, the patterns
   with nested pairs of the same shape, and [body] becomes the body of the
   function passed as [~f].
   Raises [Invalid_argument] when [bindings] is empty. *)
let expand_let extension_name ~loc ~modul bindings body =
  match bindings with
  | [] -> invalid_arg "expand_let: list of bindings must be non-empty"
  | _ :: _ ->
    (* Build expression [both E1 (both E2 (both ...))]: start from the last
       right-hand side and wrap it with each earlier one in turn. *)
    let combined_rhs =
      bindings
      |> List.rev_map ~f:(fun vb -> vb.pvb_expr)
      |> List.reduce_exn ~f:(fun inner e ->
        let loc = e.pexp_loc in
        eapply ~loc (eoperator ~loc ~modul "both") [ e; inner ])
    in
    (* Build pattern [(P1, (P2, ...))], mirroring the nesting above. *)
    let combined_lhs =
      bindings
      |> List.rev_map ~f:(fun vb -> vb.pvb_pat)
      |> List.reduce_exn ~f:(fun inner p ->
        let loc = p.ppat_loc in
        ppat_tuple ~loc [ p; inner ])
    in
    let fn = pexp_fun ~loc Nolabel None combined_lhs body in
    bind_apply ~loc ~modul extension_name ~arg:combined_rhs ~fn
;;
(* Expands a match: the scrutinee [expr] (locally opened for the [_open]
   extensions) is passed to the operator, and [cases] become the function
   given as [~f]. *)
let expand_match extension_name ~loc ~modul expr cases =
  let arg = maybe_open extension_name ~to_open:(open_on_rhs ~modul) expr in
  let fn = pexp_function ~loc cases in
  bind_apply ~loc ~modul extension_name ~arg ~fn
;;
(* Expands a conditional as a match on the condition with a [true] case
   yielding [then_] and a [false] case yielding [else_].

   Note that [~modul] is not applied to [expand_match] here, so the value
   returned is still waiting for that labelled argument; callers supply it
   as part of the full application of [expand_if]. *)
let expand_if extension_name ~loc expr then_ else_ =
  let branch value rhs = case ~lhs:(pbool ~loc value) ~guard:None ~rhs in
  let cases = [ branch true then_; branch false else_ ] in
  expand_match extension_name ~loc expr cases
;;
(* Entry point: expands the payload [expr] of an extension node.

   Non-recursive [let], [match] and [if ... then ... else ...] are supported;
   a recursive [let], an [if] without [else], or any other expression is
   reported through [Location.raise_errorf] at the location of [expr].
   Attributes carried by [expr] are prepended to those of the expansion. *)
let expand ~modul extension_name expr =
  let loc = expr.pexp_loc in
  (* Normalises one binding before expansion: undoes the duplicated type
     constraint if present, and opens [Open_on_rhs] on the right-hand side
     for the [_open] extensions. *)
  let prepare_binding vb =
    let pvb_pat =
      (* Temporary hack tentatively detecting that the parser
         has expanded `let x : t = e` into `let x : t = (e : t)`.

         For reference, here is the relevant part of the parser:
         https://github.com/ocaml/ocaml/blob/4.07/parsing/parser.mly#L1628 *)
      match vb.pvb_pat.ppat_desc, vb.pvb_expr.pexp_desc with
      | ( Ppat_constraint (p, { ptyp_desc = Ptyp_poly ([], t1); _ })
        , Pexp_constraint (_, t2) )
        when phys_equal t1 t2 -> p
      | _ -> vb.pvb_pat
    in
    let pvb_expr =
      maybe_open extension_name ~to_open:(open_on_rhs ~modul) vb.pvb_expr
    in
    { vb with pvb_pat; pvb_expr }
  in
  let expansion =
    match expr.pexp_desc with
    | Pexp_let (Recursive, _, _) ->
      Location.raise_errorf
        ~loc
        "'let%%%s' may not be recursive"
        (Extension_name.to_string extension_name)
    | Pexp_let (Nonrecursive, bindings, body) ->
      let prepared = List.map bindings ~f:prepare_binding in
      expand_with_tmp_vars ~loc prepared body ~f:(expand_let extension_name ~modul)
    | Pexp_match (scrutinee, cases) ->
      expand_match extension_name ~loc ~modul scrutinee cases
    | Pexp_ifthenelse (_, _, None) ->
      Location.raise_errorf
        ~loc
        "'if%%%s' must include an else branch"
        (Extension_name.to_string extension_name)
    | Pexp_ifthenelse (condition, then_, Some else_) ->
      expand_if extension_name ~loc ~modul condition then_ else_
    | _ ->
      Location.raise_errorf
        ~loc
        "'%%%s' can only be used with 'let', 'match', and 'if'"
        (Extension_name.to_string extension_name)
  in
  { expansion with pexp_attributes = expr.pexp_attributes @ expansion.pexp_attributes }
;;