From 72f5698c3ddb5e53deafcaf69f45d3d4a4e0184f Mon Sep 17 00:00:00 2001 From: Galfour Date: Mon, 22 Apr 2019 18:15:39 +0000 Subject: [PATCH] add variants --- src/lib_utils/PP.ml | 2 + src/lib_utils/tree.ml | 1 + src/lib_utils/x_tezos_micheline.ml | 7 ++ src/ligo/ast_simplified/PP.ml | 6 ++ src/ligo/ast_simplified/combinators.ml | 2 + src/ligo/ast_simplified/types.ml | 2 + src/ligo/ast_typed/PP.ml | 6 ++ src/ligo/ast_typed/combinators.ml | 30 +++++++ src/ligo/ast_typed/misc.ml | 6 +- src/ligo/ast_typed/types.ml | 6 +- src/ligo/bin/cli.ml | 43 +++++++--- src/ligo/compiler/compiler.ml | 2 + src/ligo/compiler/compiler_program.ml | 85 +++++++++++++++++- src/ligo/contracts/match.ligo | 9 +- src/ligo/contracts/super-counter.ligo | 10 +++ src/ligo/contracts/variant.ligo | 13 +++ src/ligo/main/contract.ml | 109 ++++++++++++++++++++++++ src/ligo/main/main.ml | 2 + src/ligo/mini_c/PP.ml | 11 ++- src/ligo/mini_c/combinators.ml | 14 +++ src/ligo/mini_c/mini_c.ml | 1 + src/ligo/mini_c/types.ml | 3 + src/ligo/parser/pascaligo/AST.ml | 2 + src/ligo/parser/pascaligo/AST.mli | 1 + src/ligo/parser/pascaligo/Parser.mly | 17 +++- src/ligo/parser/pascaligo/ParserLog.ml | 6 ++ src/ligo/parser/pascaligo/ParserLog.mli | 1 + src/ligo/simplify/pascaligo.ml | 73 +++++++++++----- src/ligo/test/integration_tests.ml | 34 +++++++- src/ligo/test/test_helpers.ml | 4 +- src/ligo/transpiler/transpiler.ml | 94 +++++++++++++++++--- src/ligo/typer/typer.ml | 76 +++++++++++++++-- 32 files changed, 609 insertions(+), 69 deletions(-) create mode 100644 src/ligo/contracts/super-counter.ligo create mode 100644 src/ligo/contracts/variant.ligo diff --git a/src/lib_utils/PP.ml b/src/lib_utils/PP.ml index a32854c6f..b82249812 100644 --- a/src/lib_utils/PP.ml +++ b/src/lib_utils/PP.ml @@ -23,6 +23,8 @@ let option = fun f ppf opt -> | Some x -> fprintf ppf "Some(%a)" f x | None -> fprintf ppf "None" +let int = fun ppf n -> fprintf ppf "%d" n + let map = fun f pp ppf x -> pp ppf (f x) diff --git a/src/lib_utils/tree.ml b/src/lib_utils/tree.ml index 7b4c5886a..1893e57a6 100644 --- a/src/lib_utils/tree.ml +++ b/src/lib_utils/tree.ml @@ -94,6 +94,7 @@ module Append = struct | Empty -> empty | Full x -> fold' leaf node x + let rec assoc_opt' : ('a * 'b) t' -> 'a -> 'b option = fun t k -> match t with | Leaf (k', v) when k = k' -> Some v diff --git a/src/lib_utils/x_tezos_micheline.ml b/src/lib_utils/x_tezos_micheline.ml index 2219c74c0..ff497414f 100644 --- a/src/lib_utils/x_tezos_micheline.ml +++ b/src/lib_utils/x_tezos_micheline.ml @@ -18,6 +18,13 @@ module Michelson = struct let i_comment s : michelson = prim ~annot:["\"" ^ s ^ "\""] I_NOP + let contract parameter storage code = + seq [ + prim ~children:[parameter] K_parameter ; + prim ~children:[storage] K_storage ; + prim ~children:[code] K_code ; + ] + let int n : michelson = Int (0, n) let string s : michelson = String (0, s) let bytes s : michelson = Bytes (0, s) diff --git a/src/ligo/ast_simplified/PP.ml b/src/ligo/ast_simplified/PP.ml index b6d7beb0a..a28cc0bbe 100644 --- a/src/ligo/ast_simplified/PP.ml +++ b/src/ligo/ast_simplified/PP.ml @@ -66,10 +66,16 @@ and block ppf (b:block) = (list_sep instruction (tag "@;")) ppf b and single_record_patch ppf ((p, ae) : string * ae) = fprintf ppf "%s <- %a" p annotated_expression ae +and matching_variant_case : type a . (_ -> a -> unit) -> _ -> (constructor_name * name) * a -> unit = + fun f ppf ((c,n),a) -> + fprintf ppf "| %s %s -> %a" c n f a + and matching : type a . (formatter -> a -> unit) -> formatter -> a matching -> unit = fun f ppf m -> match m with | Match_tuple (lst, b) -> fprintf ppf "let (%a) = %a" (list_sep_d string) lst f b + | Match_variant lst -> + fprintf ppf "%a" (list_sep (matching_variant_case f) (tag "@.")) lst | Match_bool {match_true ; match_false} -> fprintf ppf "| True -> %a @.| False -> %a" f match_true f match_false | Match_list {match_nil ; match_cons = (hd, tl, match_cons)} -> diff --git a/src/ligo/ast_simplified/combinators.ml b/src/ligo/ast_simplified/combinators.ml index e7b6986f9..833e70f71 100644 --- a/src/ligo/ast_simplified/combinators.ml +++ b/src/ligo/ast_simplified/combinators.ml @@ -56,11 +56,13 @@ let e_none : expression = E_constant ("NONE", []) let e_map lst : expression = E_map lst let e_list lst : expression = E_list lst let e_pair a b : expression = E_tuple [a; b] +let e_constructor s a : expression = E_constructor (s , a) let e_a_int n : annotated_expression = make_e_a_full (e_int n) t_int let e_a_nat n : annotated_expression = make_e_a_full (e_nat n) t_nat let e_a_bool b : annotated_expression = make_e_a_full (e_bool b) t_bool let e_a_unit : annotated_expression = make_e_a_full (e_unit ()) t_unit +let e_a_constructor s a : annotated_expression = make_e_a (e_constructor s a) let e_a_record r = let type_annotation = Option.( diff --git a/src/ligo/ast_simplified/types.ml b/src/ligo/ast_simplified/types.ml index fc1472b06..d6325ce66 100644 --- a/src/ligo/ast_simplified/types.ml +++ b/src/ligo/ast_simplified/types.ml @@ -1,5 +1,6 @@ type name = string type type_name = string +type constructor_name = string type 'a name_map = 'a Map.String.t type 'a type_name_map = 'a Map.String.t @@ -109,6 +110,7 @@ and 'a matching = match_some : name * 'a ; } | Match_tuple of name list * 'a + | Match_variant of ((constructor_name * name) * 'a) list and matching_instr = b matching diff --git a/src/ligo/ast_typed/PP.ml b/src/ligo/ast_typed/PP.ml index 00d25a98d..d319ae0ed 100644 --- a/src/ligo/ast_typed/PP.ml +++ b/src/ligo/ast_typed/PP.ml @@ -66,9 +66,15 @@ and block ppf (b:block) = (list_sep instruction (tag "@;")) ppf b and single_record_patch ppf ((s, ae) : string * ae) = fprintf ppf "%s <- %a" s annotated_expression ae +and matching_variant_case : type a . (_ -> a -> unit) -> _ -> (constructor_name * name) * a -> unit = + fun f ppf ((c,n),a) -> + fprintf ppf "| %s %s -> %a" c n f a + and matching : type a . (formatter -> a -> unit) -> _ -> a matching -> unit = fun f ppf m -> match m with | Match_tuple (lst, b) -> fprintf ppf "let (%a) = %a" (list_sep_d (fun ppf -> fprintf ppf "%s")) lst f b + | Match_variant (lst , _) -> + fprintf ppf "%a" (list_sep (matching_variant_case f) (tag "@.")) lst | Match_bool {match_true ; match_false} -> fprintf ppf "| True -> %a @.| False -> %a" f match_true f match_false | Match_list {match_nil ; match_cons = (hd, tl, match_cons)} -> diff --git a/src/ligo/ast_typed/combinators.ml b/src/ligo/ast_typed/combinators.ml index 5d049c9d1..5caa7193c 100644 --- a/src/ligo/ast_typed/combinators.ml +++ b/src/ligo/ast_typed/combinators.ml @@ -55,6 +55,18 @@ let get_t_tuple (t:type_value) : type_value list result = match t.type_value' wi | T_tuple lst -> ok lst | _ -> simple_fail "not a tuple" +let get_t_pair (t:type_value) : (type_value * type_value) result = match t.type_value' with + | T_tuple lst -> + let%bind () = + trace_strong (simple_error "not a pair") @@ + Assert.assert_list_size lst 2 in + ok List.(nth lst 0 , nth lst 1) + | _ -> simple_fail "not a tuple" + +let get_t_function (t:type_value) : (type_value * type_value) result = match t.type_value' with + | T_function ar -> ok ar + | _ -> simple_fail "not a tuple" + let get_t_sum (t:type_value) : type_value SMap.t result = match t.type_value' with | T_sum m -> ok m | _ -> simple_fail "not a sum" @@ -67,6 +79,7 @@ let get_t_map (t:type_value) : (type_value * type_value) result = match t.type_value' with | T_constant ("map", [k;v]) -> ok (k, v) | _ -> simple_fail "get: not a map" + let assert_t_map (t:type_value) : unit result = match t.type_value' with | T_constant ("map", [_ ; _]) -> ok () @@ -77,6 +90,15 @@ let assert_t_list (t:type_value) : unit result = | T_constant ("list", [_]) -> ok () | _ -> simple_fail "assert: not a list" +let assert_t_operation (t:type_value) : unit result = + match t.type_value' with + | T_constant ("operation" , []) -> ok () + | _ -> simple_fail "assert: not an operation" + +let assert_t_list_operation (t : type_value) : unit result = + let%bind t' = get_t_list t in + assert_t_operation t' + let assert_t_int : type_value -> unit result = fun t -> match t.type_value' with | T_constant ("int", []) -> ok () | _ -> simple_fail "not an int" @@ -146,6 +168,14 @@ let get_a_bool (t:annotated_expression) = | E_literal (Literal_bool b) -> ok b | _ -> simple_fail "not a bool" +let get_declaration_by_name : program -> string -> declaration result = fun p name -> + let aux : declaration -> bool = fun declaration -> + match declaration with + | Declaration_constant d -> d.name = name + in + trace_option (simple_error "no declaration with given name") @@ + List.find_opt aux @@ List.map Location.unwrap p + open Environment let env_sum_type ?(env = full_empty) ?(name = "a_sum_type") diff --git a/src/ligo/ast_typed/misc.ml b/src/ligo/ast_typed/misc.ml index f4c28145a..81e6418e7 100644 --- a/src/ligo/ast_typed/misc.ml +++ b/src/ligo/ast_typed/misc.ml @@ -85,12 +85,16 @@ module Free_variables = struct let (_ , frees) = block' b bl in frees + and matching_variant_case : type a . (bindings -> a -> bindings) -> bindings -> ((constructor_name * name) * a) -> bindings = fun f b ((_,n),c) -> + f (union (singleton n) b) c + and matching : type a . (bindings -> a -> bindings) -> bindings -> a matching -> bindings = fun f b m -> match m with | Match_bool { match_true = t ; match_false = fa } -> union (f b t) (f b fa) | Match_list { match_nil = n ; match_cons = (hd, tl, c) } -> union (f b n) (f (union (of_list [hd ; tl]) b) c) | Match_option { match_none = n ; match_some = ((opt, _), s) } -> union (f b n) (f (union (singleton opt) b) s) - | Match_tuple (lst, a) -> f (union (of_list lst) b) a + | Match_tuple (lst , a) -> f (union (of_list lst) b) a + | Match_variant (lst , _) -> unions @@ List.map (matching_variant_case f b) lst and matching_expression = fun x -> matching annotated_expression x diff --git a/src/ligo/ast_typed/types.ml b/src/ligo/ast_typed/types.ml index d226c2c5c..a9b4bd1aa 100644 --- a/src/ligo/ast_typed/types.ml +++ b/src/ligo/ast_typed/types.ml @@ -6,6 +6,7 @@ module SMap = Map.String type name = string type type_name = string +type constructor_name = string type 'a name_map = 'a SMap.t type 'a type_name_map = 'a SMap.t @@ -47,7 +48,7 @@ and type_value' = | T_sum of tv_map | T_record of tv_map | T_constant of type_name * tv list - | T_function of tv * tv + | T_function of (tv * tv) and type_value = { type_value' : type_value' ; @@ -128,7 +129,8 @@ and 'a matching = match_none : 'a ; match_some : (name * type_value) * 'a ; } - | Match_tuple of name list * 'a + | Match_tuple of (name list * 'a) + | Match_variant of (((constructor_name * name) * 'a) list * type_value) and matching_instr = b matching diff --git a/src/ligo/bin/cli.ml b/src/ligo/bin/cli.ml index 2c1cd6ffa..64ca6bbe9 100644 --- a/src/ligo/bin/cli.ml +++ b/src/ligo/bin/cli.ml @@ -13,22 +13,37 @@ let main () = then simple_fail "Pass a command" else ok () in let command = Sys.argv.(1) in - (* Format.printf "Processing command %s (%d)\n" command l ; *) match command with | "compile" -> ( - let%bind () = - if l <> 4 - then simple_fail "Bad number of argument to compile" - else ok () in - let source = Sys.argv.(2) in - let entry_point = Sys.argv.(3) in - (* Format.printf "Compiling %s from %s\n%!" entry_point source ; *) - let%bind michelson = - trace (simple_error "compile michelson") @@ - Ligo.compile_file source entry_point in - Format.printf "Program : %a\n" Micheline.Michelson.pp michelson ; - ok () + let sub_command = Sys.argv.(2) in + match sub_command with + | "file" -> ( + let%bind () = + trace_strong (simple_error "bad number of args") @@ + Assert.assert_equal_int 5 l in + let source = Sys.argv.(3) in + let entry_point = Sys.argv.(4) in + let%bind contract = + trace (simple_error "compile michelson") @@ + Ligo.Contract.compile_contract_file source entry_point in + Format.printf "Contract:\n%s\n" contract ; + ok () + ) + | "expression" -> ( + let%bind () = + trace_strong (simple_error "bad number of args") @@ + Assert.assert_equal_int 6 l in + let source = Sys.argv.(3) in + let entry_point = Sys.argv.(4) in + let expression = Sys.argv.(5) in + let%bind value = + trace (simple_error "compile expression") @@ + Ligo.Contract.compile_contract_parameter source entry_point expression in + Format.printf "Input:\n%s\n" value; + ok () + ) + | _ -> simple_fail "Bad sub-command" ) - | _ -> simple_fail "Bad command" + | _ -> simple_fail "Bad command" let () = toplevel @@ main () diff --git a/src/ligo/compiler/compiler.ml b/src/ligo/compiler/compiler.ml index 1a306f0dc..fbdd8942a 100644 --- a/src/ligo/compiler/compiler.ml +++ b/src/ligo/compiler/compiler.ml @@ -2,3 +2,5 @@ module Uncompiler = Uncompiler module Program = Compiler_program module Type = Compiler_type module Environment = Compiler_environment + +include Program diff --git a/src/ligo/compiler/compiler_program.ml b/src/ligo/compiler/compiler_program.ml index 44f27a73f..ef4f661f0 100644 --- a/src/ligo/compiler/compiler_program.ml +++ b/src/ligo/compiler/compiler_program.ml @@ -10,7 +10,7 @@ open Memory_proto_alpha.Script_ir_translator open Operators.Compiler -let get_predicate : string -> expression list -> predicate result = fun s lst -> +let get_predicate : string -> type_value -> expression list -> predicate result = fun s ty lst -> match Map.String.find_opt s Operators.Compiler.predicates with | Some x -> ok x | None -> ( @@ -23,6 +23,18 @@ let get_predicate : string -> expression list -> predicate result = fun s lst -> | _ -> simple_fail "mini_c . MAP_REMOVE" in let%bind v_ty = Compiler_type.type_ v in ok @@ simple_binary @@ seq [dip (i_none v_ty) ; prim I_UPDATE ] + | "LEFT" -> + let%bind r = match lst with + | [ _ ] -> get_t_right ty + | _ -> simple_fail "mini_c . LEFT" in + let%bind r_ty = Compiler_type.type_ r in + ok @@ simple_unary @@ prim ~children:[r_ty] I_LEFT + | "RIGHT" -> + let%bind l = match lst with + | [ _ ] -> get_t_left ty + | _ -> simple_fail "mini_c . RIGHT" in + let%bind l_ty = Compiler_type.type_ l in + ok @@ simple_unary @@ prim ~children:[l_ty] I_RIGHT | x -> simple_fail ("predicate \"" ^ x ^ "\" doesn't exist") ) @@ -181,7 +193,7 @@ and translate_expression ?(first=false) (expr:expression) : michelson result = let first = first && i = 0 in translate_expression ~first e in bind_list @@ List.mapi aux lst in - let%bind predicate = get_predicate str lst in + let%bind predicate = get_predicate str ty lst in let%bind code = match (predicate, List.length lst) with | Constant c, 0 -> ok @@ virtual_push_first @@ seq @@ lst' @ [ c ; @@ -264,6 +276,58 @@ and translate_expression ?(first=false) (expr:expression) : michelson result = ]) in return code ) + | E_if_none (c, n, (_ , s)) -> ( + let%bind c' = translate_expression c in + let%bind n' = translate_expression n in + let%bind s' = translate_expression s in + let%bind restrict = Compiler_environment.to_michelson_restrict s.environment in + let%bind code = ok (seq [ + c' ; i_unpair ; + i_if_none n' (seq [ + i_pair ; + s' ; + restrict ; + ]) + ; + ]) in + return code + ) + | E_if_left (c, (_ , l), (_ , r)) -> ( + let%bind c' = translate_expression c in + let%bind l' = translate_expression l in + let%bind r' = translate_expression r in + let%bind restrict_l = Compiler_environment.to_michelson_restrict l.environment in + let%bind restrict_r = Compiler_environment.to_michelson_restrict r.environment in + let%bind code = ok (seq [ + c' ; i_unpair ; + i_if_none (seq [ + i_pair ; + l' ; + i_unpair ; + dip restrict_l ; + ]) (seq [ + i_pair ; + r' ; + i_unpair ; + dip restrict_r ; + ]) + ; + ]) in + return code + ) + | E_let_in (_, expr , body) -> ( + let%bind expr' = translate_expression expr in + let%bind body' = translate_expression body in + let%bind restrict = Compiler_environment.to_michelson_restrict body.environment in + let%bind code = ok (seq [ + expr' ; + i_unpair ; + i_swap ; dip i_pair ; + body' ; + restrict ; + ]) in + return code + ) in ok code @@ -277,7 +341,7 @@ and translate_statement ((s', w_env) as s:statement) : michelson result = | S_environment_restrict -> Compiler_environment.to_michelson_restrict w_env.pre_environment | S_environment_add _ -> - simple_fail "not ready yet" + simple_fail "add not ready yet" (* | S_environment_add (name, tv) -> * Environment.to_michelson_add (name, tv) w_env.pre_environment *) | S_declaration (s, expr) -> @@ -490,7 +554,7 @@ type compiled_program = { body : michelson ; } -let translate_program (p:program) (entry:string) : compiled_program result = +let get_main : program -> string -> anon_function_content result = fun p entry -> let is_main (((name , expr), _):toplevel_statement) = match Combinators.Expression.(get_content expr , get_type expr)with | (E_function f , T_function _) @@ -505,12 +569,25 @@ let translate_program (p:program) (entry:string) : compiled_program result = trace_option (simple_error "no functional entry") @@ Tezos_utils.List.find_map is_main p in + ok main + +let translate_program (p:program) (entry:string) : compiled_program result = + let%bind main = get_main p entry in let {input;output} : anon_function_content = main in let%bind body = translate_quote_body main in let%bind input = Compiler_type.Ty.type_ input in let%bind output = Compiler_type.Ty.type_ output in ok ({input;output;body}:compiled_program) +let translate_contract : program -> string -> michelson result = fun p e -> + let%bind main = get_main p e in + let%bind (param_ty , storage_ty) = Combinators.get_t_pair main.input in + let%bind param_michelson = Compiler_type.type_ param_ty in + let%bind storage_michelson = Compiler_type.type_ storage_ty in + let%bind { body = code } = translate_program p e in + let contract = Michelson.contract param_michelson storage_michelson code in + ok contract + let translate_entry (p:anon_function) : compiled_program result = let {input;output} : anon_function_content = p.content in let%bind body = diff --git a/src/ligo/contracts/match.ligo b/src/ligo/contracts/match.ligo index 32ea91625..57a74d7dd 100644 --- a/src/ligo/contracts/match.ligo +++ b/src/ligo/contracts/match.ligo @@ -12,7 +12,7 @@ function match_option (const o : option(int)) : int is begin case o of | None -> skip - | Some(s) -> skip // result := s + | Some(s) -> result := s end end with result @@ -22,3 +22,10 @@ function match_expr_bool (const i : int) : int is | True -> 42 | False -> 0 end + +function match_expr_option (const o : option(int)) : int is + begin skip end with + case o of + | None -> 42 + | Some(s) -> s + end diff --git a/src/ligo/contracts/super-counter.ligo b/src/ligo/contracts/super-counter.ligo new file mode 100644 index 000000000..c8f053d18 --- /dev/null +++ b/src/ligo/contracts/super-counter.ligo @@ -0,0 +1,10 @@ +type action = +| Increment of int +| Decrement of int + +function main (const p : action ; const s : int) : (list(operation) * int) is + block {skip} with ((nil : operation), + match p with + | Increment n -> s + n + | Decrement n -> s - n + end) diff --git a/src/ligo/contracts/variant.ligo b/src/ligo/contracts/variant.ligo new file mode 100644 index 000000000..4ccb21418 --- /dev/null +++ b/src/ligo/contracts/variant.ligo @@ -0,0 +1,13 @@ +type foobar is +| Foo of int +| Bar of bool + +const foo : foobar = Foo (42) + +const bar : foobar = Bar (True) + +function fb(const p : foobar) : int is + block { skip } with (case p of + | Foo (n) -> n + | Bar (t) -> 42 + end) diff --git a/src/ligo/main/contract.ml b/src/ligo/main/contract.ml index 57434b8df..795bb2a7c 100644 --- a/src/ligo/main/contract.ml +++ b/src/ligo/main/contract.ml @@ -14,3 +14,112 @@ include struct trace_strong (simple_error "no entry-point with given name") @@ Assert.assert_true @@ List.exists aux @@ List.map Location.unwrap program end + +include struct + open Ast_typed + open Combinators + + let assert_entry_point_type : type_value -> unit result = fun t -> + let%bind (arg , result) = + trace_strong (simple_error "entry-point doesn't have a function type") @@ + get_t_function t in + let%bind (_ , storage_param) = + trace_strong (simple_error "entry-point doesn't have 2 parameters") @@ + get_t_pair arg in + let%bind (ops , storage_result) = + trace_strong (simple_error "entry-point doesn't have 2 results") @@ + get_t_pair result in + let%bind () = + trace_strong (simple_error "entry-point doesn't have a list of operation as first result") @@ + assert_t_list_operation ops in + let%bind () = + trace_strong (simple_error "entry-point doesn't identitcal type (storage) for second parameter and second result") @@ + assert_type_value_eq (storage_param , storage_result) in + ok () + + let assert_valid_entry_point : program -> string -> unit result = fun p e -> + let%bind declaration = get_declaration_by_name p e in + match declaration with + | Declaration_constant d -> assert_entry_point_type d.annotated_expression.type_annotation +end + +let transpile_value + (e:Ast_typed.annotated_expression) : Mini_c.value result = + let%bind f = + let open Transpiler in + let (f, t) = functionalize e in + let%bind main = translate_main f t in + ok main + in + + let input = Mini_c.Combinators.d_unit in + let%bind r = Run_mini_c.run_entry f input in + ok r + +let compile_contract_file : string -> string -> string result = fun source entry_point -> + let%bind raw = + trace (simple_error "parsing") @@ + Parser.parse_file source in + let%bind simplified = + trace (simple_error "simplifying") @@ + Simplify.Pascaligo.simpl_program raw in + let%bind () = + assert_entry_point_defined simplified entry_point in + let%bind typed = + trace (simple_error "typing") @@ + Typer.type_program simplified in + let%bind () = + assert_valid_entry_point typed entry_point in + let%bind mini_c = + trace (simple_error "transpiling") @@ + Transpiler.translate_program typed in + let%bind michelson = + trace (simple_error "compiling") @@ + Compiler.translate_contract mini_c entry_point in + let str = + Format.asprintf "%a" Micheline.Michelson.pp michelson in + ok str + +let compile_contract_parameter : string -> string -> string -> string result = fun source entry_point expression -> + let%bind parameter_tv = + let%bind raw = + trace (simple_error "parsing file") @@ + Parser.parse_file source in + let%bind simplified = + trace (simple_error "simplifying file") @@ + Simplify.Pascaligo.simpl_program raw in + let%bind () = + assert_entry_point_defined simplified entry_point in + let%bind typed = + trace (simple_error "typing file") @@ + Typer.type_program simplified in + let%bind () = + assert_valid_entry_point typed entry_point in + let%bind declaration = Ast_typed.Combinators.get_declaration_by_name typed entry_point in + match declaration with + | Declaration_constant d -> ok d.annotated_expression.type_annotation + in + let%bind expr = + let%bind raw = + trace (simple_error "parsing expression") @@ + Parser.parse_expression expression in + let%bind simplified = + trace (simple_error "simplifying expression") @@ + Simplify.Pascaligo.simpl_expression raw in + let%bind typed = + trace (simple_error "typing expression") @@ + Typer.type_annotated_expression Ast_typed.Environment.full_empty simplified in + let%bind () = + trace (simple_error "expression type doesn't match type parameter") @@ + Ast_typed.assert_type_value_eq (parameter_tv , typed.type_annotation) in + let%bind mini_c = + trace (simple_error "transpiling expression") @@ + transpile_value typed in + let%bind michelson = + trace (simple_error "compiling expression") @@ + Compiler.translate_value mini_c in + let str = + Format.asprintf "%a" Micheline.Michelson.pp michelson in + ok str + in + ok expr diff --git a/src/ligo/main/main.ml b/src/ligo/main/main.ml index df9cfeb67..bc45dd51f 100644 --- a/src/ligo/main/main.ml +++ b/src/ligo/main/main.ml @@ -178,3 +178,5 @@ let compile_file (source: string) (entry_point:string) : Micheline.Michelson.t r trace (simple_error "compiling") @@ compile mini_c entry_point in ok michelson + +module Contract = Contract diff --git a/src/ligo/mini_c/PP.ml b/src/ligo/mini_c/PP.ml index fdcf39524..df09d281e 100644 --- a/src/ligo/mini_c/PP.ml +++ b/src/ligo/mini_c/PP.ml @@ -72,14 +72,19 @@ and expression' ppf (e:expression') = match e with | E_empty_list _ -> fprintf ppf "list[]" | E_make_none _ -> fprintf ppf "none" | E_Cond (c, a, b) -> fprintf ppf "%a ? %a : %a" expression c expression a expression b + | E_if_none (c, n, ((name, _) , s)) -> fprintf ppf "%a ?? %a : %s -> %a" expression c expression n name expression s + | E_if_left (c, ((name_l, _) , l), ((name_r, _) , r)) -> + fprintf ppf "%a ?? %s -> %a : %s -> %a" expression c name_l expression l name_r expression r + | E_let_in ((name , _) , expr , body) -> + fprintf ppf "let %s = %a in %a" name expression expr expression body and expression : _ -> expression -> _ = fun ppf e -> - expression' ppf (Combinators.Expression.get_content e) + expression' ppf e.content and expression_with_type : _ -> expression -> _ = fun ppf e -> fprintf ppf "%a : %a" - expression' (Combinators.Expression.get_content e) - type_ (Combinators.Expression.get_type e) + expression' e.content + type_ e.type_value and function_ ppf ({binder ; input ; output ; body ; result ; capture_type}:anon_function_content) = fprintf ppf "fun[%s] (%s:%a) : %a %a return %a" diff --git a/src/ligo/mini_c/combinators.ml b/src/ligo/mini_c/combinators.ml index 53ceaf8de..e40c27efb 100644 --- a/src/ligo/mini_c/combinators.ml +++ b/src/ligo/mini_c/combinators.ml @@ -95,6 +95,19 @@ let get_or (v:value) = match v with | D_right b -> ok (true, b) | _ -> simple_fail "not a left/right" +let wrong_type name t = + let title () = "not a " ^ name in + let content () = Format.asprintf "%a" PP.type_ t in + error title content + +let get_t_left t = match t with + | T_or (a , _) -> ok a + | _ -> fail @@ wrong_type "union" t + +let get_t_right t = match t with + | T_or (_ , b) -> ok b + | _ -> fail @@ wrong_type "union" t + let get_last_statement ((b', _):block) : statement result = let aux lst = match lst with | [] -> simple_fail "get_last: empty list" @@ -107,6 +120,7 @@ let t_nat : type_value = T_base Base_nat let t_function x y : type_value = T_function ( x , y ) let t_shallow_closure x y z : type_value = T_shallow_closure ( x , y , z ) let t_pair x y : type_value = T_pair ( x , y ) +let t_union x y : type_value = T_or ( x , y ) let quote binder input output body result : anon_function = let content : anon_function_content = { diff --git a/src/ligo/mini_c/mini_c.ml b/src/ligo/mini_c/mini_c.ml index c919abaae..5f4e9f5a2 100644 --- a/src/ligo/mini_c/mini_c.ml +++ b/src/ligo/mini_c/mini_c.ml @@ -6,4 +6,5 @@ module Combinators = struct include Combinators include Combinators_smart end +include Combinators module Environment = Environment diff --git a/src/ligo/mini_c/types.ml b/src/ligo/mini_c/types.ml index 5be6bb02a..3ba6a7571 100644 --- a/src/ligo/mini_c/types.ml +++ b/src/ligo/mini_c/types.ml @@ -64,6 +64,9 @@ and expression' = | E_empty_list of type_value | E_make_none of type_value | E_Cond of expression * expression * expression + | E_if_none of expression * expression * ((var_name * type_value) * expression) + | E_if_left of expression * ((var_name * type_value) * expression) * ((var_name * type_value) * expression) + | E_let_in of ((var_name * type_value) * expression * expression) and expression = { content : expression' ; diff --git a/src/ligo/parser/pascaligo/AST.ml b/src/ligo/parser/pascaligo/AST.ml index 6b89d2434..0a7b05375 100644 --- a/src/ligo/parser/pascaligo/AST.ml +++ b/src/ligo/parser/pascaligo/AST.ml @@ -646,6 +646,7 @@ and arguments = tuple_injection and pattern = PCons of (pattern, cons) nsepseq reg +| PConstr of (constr * pattern reg) reg | PVar of Lexer.lexeme reg | PWild of wild | PInt of (Lexer.lexeme * Z.t) reg @@ -792,6 +793,7 @@ let pattern_to_region = function | PList Sugar {region; _} | PList PNil region | PList Raw {region; _} +| PConstr {region; _} | PTuple {region; _} -> region let local_decl_to_region = function diff --git a/src/ligo/parser/pascaligo/AST.mli b/src/ligo/parser/pascaligo/AST.mli index 6901c2607..92ad6829d 100644 --- a/src/ligo/parser/pascaligo/AST.mli +++ b/src/ligo/parser/pascaligo/AST.mli @@ -630,6 +630,7 @@ and arguments = tuple_injection and pattern = PCons of (pattern, cons) nsepseq reg +| PConstr of (constr * pattern reg) reg | PVar of Lexer.lexeme reg | PWild of wild | PInt of (Lexer.lexeme * Z.t) reg diff --git a/src/ligo/parser/pascaligo/Parser.mly b/src/ligo/parser/pascaligo/Parser.mly index e22e6e2a7..43209aa8f 100644 --- a/src/ligo/parser/pascaligo/Parser.mly +++ b/src/ligo/parser/pascaligo/Parser.mly @@ -197,9 +197,9 @@ type_tuple: par(nsepseq(type_expr,COMMA)) { $1 } sum_type: - nsepseq(variant,VBAR) { - let region = nsepseq_to_region (fun x -> x.region) $1 - in {region; value = $1}} + option(VBAR) nsepseq(variant,VBAR) { + let region = nsepseq_to_region (fun x -> x.region) $2 + in {region; value = $2}} variant: Constr Of cartesian { @@ -1092,6 +1092,7 @@ core_pattern: | C_None { PNone $1 } | list_patt { PList $1 } | tuple_patt { PTuple $1 } +| constr_patt { PConstr $1 } | C_Some par(core_pattern) { let region = cover $1 $2.region in PSome {region; value = $1,$2}} @@ -1106,3 +1107,13 @@ cons_pattern: tuple_patt: par(nsepseq(core_pattern,COMMA)) { $1 } + +constr_patt: + Constr core_pattern { + let second = + let region = pattern_to_region $2 in + {region; value=$2} + in + let region = cover $1.region second.region in + let value = ($1 , second) in + {region; value}} diff --git a/src/ligo/parser/pascaligo/ParserLog.ml b/src/ligo/parser/pascaligo/ParserLog.ml index f78edc7a2..08ea20431 100644 --- a/src/ligo/parser/pascaligo/ParserLog.ml +++ b/src/ligo/parser/pascaligo/ParserLog.ml @@ -682,6 +682,12 @@ and print_pattern = function | PSome psome -> print_psome psome | PList pattern -> print_list_pattern pattern | PTuple ptuple -> print_ptuple ptuple +| PConstr pattern -> print_constr_pattern pattern + +and print_constr_pattern {value; _} = + let (constr, args) = value in + print_constr constr ; + print_pattern args.value ; and print_psome {value; _} = let c_Some, patterns = value in diff --git a/src/ligo/parser/pascaligo/ParserLog.mli b/src/ligo/parser/pascaligo/ParserLog.mli index 9211b081a..637a15438 100644 --- a/src/ligo/parser/pascaligo/ParserLog.mli +++ b/src/ligo/parser/pascaligo/ParserLog.mli @@ -6,3 +6,4 @@ val mode : [`Byte | `Point] ref val print_tokens : AST.t -> unit val print_path : AST.path -> unit +val print_pattern : AST.pattern -> unit diff --git a/src/ligo/simplify/pascaligo.ml b/src/ligo/simplify/pascaligo.ml index 517eb83ef..7657d39f9 100644 --- a/src/ligo/simplify/pascaligo.ml +++ b/src/ligo/simplify/pascaligo.ml @@ -483,29 +483,52 @@ and simpl_cases : type a . (Raw.pattern * a) list -> a matching result = fun t - let open Raw in let get_var (t:Raw.pattern) = match t with | PVar v -> ok v.value - | _ -> simple_fail "not a var" + | _ -> + let error = + let title () = "not a var" in + let content () = Format.asprintf "%a" (PP_helpers.printer Parser.Pascaligo.ParserLog.print_pattern) t in + error title content + in + fail error in - let%bind _assert = - trace_strong (simple_error "only pattern with two cases supported now") @@ - Assert.assert_equal_int 2 (List.length t) in - let ((pa, ba), (pb, bb)) = List.(hd t, hd @@ tl t) in - let uncons p = match p with - | PCons {value = (hd, _)} -> ok hd - | _ -> simple_fail "uncons fail" in - let%bind (pa, pb) = bind_map_pair uncons (pa, pb) in - match (pa, ba), (pb, bb) with - | (PFalse _, f), (PTrue _, t) - | (PTrue _, t), (PFalse _, f) -> ok @@ Match_bool {match_true = t ; match_false = f} - | (PSome v, some), (PNone _, none) - | (PNone _, none), (PSome v, some) -> ( + let get_tuple (t:Raw.pattern) = match t with + | PCons v -> npseq_to_list v.value + | PTuple v -> npseq_to_list v.value.inside + | x -> [ x ] + in + let get_single (t:Raw.pattern) = + let t' = get_tuple t in + let%bind () = + trace_strong (simple_error "not single") @@ + Assert.assert_list_size t' 1 in + ok (List.hd t') in + let get_constr (t:Raw.pattern) = match t with + | PConstr v -> + let%bind var = get_single (snd v.value).value >>? get_var in + ok ((fst v.value).value , var) + | _ -> simple_fail "not a constr" + in + let%bind patterns = + let aux (x , y) = + let xs = get_tuple x in + trace_strong (simple_error "no tuple in patterns yet") @@ + Assert.assert_list_size xs 1 >>? fun () -> + ok (List.hd xs , y) + in + bind_map_list aux t in + match patterns with + | [(PFalse _ , f) ; (PTrue _ , t)] + | [(PTrue _ , t) ; (PFalse _ , f)] -> ok @@ Match_bool {match_true = t ; match_false = f} + | [(PSome v , some) ; (PNone _ , none)] + | [(PNone _ , none) ; (PSome v , some)] -> ( let (_, v) = v.value in let%bind v = match v.value.inside with | PVar v -> ok v.value | _ -> simple_fail "complex none patterns not supported yet" in ok @@ Match_option {match_none = none ; match_some = (v, some) } ) - | (PCons c, cons), (PList (PNil _), nil) - | (PList (PNil _), nil), (PCons c, cons) -> + | [(PCons c , cons) ; (PList (PNil _) , nil)] + | [(PList (PNil _) , nil) ; (PCons c, cons)] -> let%bind (a, b) = match c.value with | a, [(_, b)] -> @@ -515,9 +538,21 @@ and simpl_cases : type a . (Raw.pattern * a) list -> a matching result = fun t - | _ -> simple_fail "complex list patterns not supported yet" in ok @@ Match_list {match_cons = (a, b, cons) ; match_nil = nil} - | _ -> - let error () = simple_error "multi-level patterns not supported yet" () in - fail error + | lst -> + trace (simple_error "weird patterns not supported yet") @@ + let%bind constrs = + let aux (x , y) = + let error = + let title () = "Pattern" in + let content () = + Format.asprintf "Pattern : %a" (PP_helpers.printer Parser.Pascaligo.ParserLog.print_pattern) x in + error title content in + let%bind x' = + trace error @@ + get_constr x in + ok (x' , y) in + bind_map_list aux lst in + ok @@ Match_variant constrs and simpl_instruction_block : Raw.instruction -> block result = fun t -> match t with diff --git a/src/ligo/test/integration_tests.ml b/src/ligo/test/integration_tests.ml index b5c9ab633..70b0ee14a 100644 --- a/src/ligo/test/integration_tests.ml +++ b/src/ligo/test/integration_tests.ml @@ -14,6 +14,20 @@ let complex_function () : unit result = let make_expect = fun n -> (3 * n + 2) in expect_n_int program "main" make_expect +let variant () : unit result = + let%bind program = type_file "./contracts/variant.ligo" in + let%bind () = + let expected = e_a_constructor "Foo" (e_a_int 42) in + expect_evaluate program "foo" expected in + let%bind () = + let expected = e_a_constructor "Bar" (e_a_bool true) in + expect_evaluate program "bar" expected in + (* let%bind () = + * let make_expect = fun n -> (3 * n + 2) in + * expect_n_int program "fb" make_expect + * in *) + ok () + let closure () : unit result = let%bind program = type_file "./contracts/closure.ligo" in let%bind () = @@ -257,12 +271,29 @@ let matching () : unit result = let input = match n with | Some s -> e_a_some (e_a_int s) | None -> e_a_none t_int in - let expected = e_a_int 23 in + let expected = e_a_int (match n with + | Some s -> s + | None -> 23) in + trace (simple_error (Format.asprintf "on input %a" PP_helpers.(option int) n)) @@ expect program "match_option" input expected in bind_iter_list aux [Some 0 ; Some 2 ; Some 42 ; Some 163 ; Some (-1) ; None] in + let%bind () = + let aux n = + let input = match n with + | Some s -> e_a_some (e_a_int s) + | None -> e_a_none t_int in + let expected = e_a_int (match n with + | Some s -> s + | None -> 42) in + trace (simple_error (Format.asprintf "on input %a" PP_helpers.(option int) n)) @@ + expect program "match_expr_option" input expected + in + bind_iter_list aux + [Some 0 ; Some 2 ; Some 42 ; Some 163 ; Some (-1) ; None] + in ok () let declarations () : unit result = @@ -292,6 +323,7 @@ let counter_contract () : unit result = let main = "Integration (End to End)", [ test "function" function_ ; test "complex function" complex_function ; + test "variant" variant ; test "closure" closure ; test "shared function" shared_function ; test "shadow" shadow ; diff --git a/src/ligo/test/test_helpers.ml b/src/ligo/test/test_helpers.ml index bb266e469..3012790b3 100644 --- a/src/ligo/test/test_helpers.ml +++ b/src/ligo/test/test_helpers.ml @@ -32,10 +32,12 @@ let expect_evaluate program entry_point expected = Ast_simplified.assert_value_eq (expected , result) let expect_n_aux lst program entry_point make_input make_expected = + Format.printf "expect_n aux\n%!" ; let aux n = let input = make_input n in let expected = make_expected n in - expect program entry_point input expected + let result = expect program entry_point input expected in + result in let%bind _ = bind_map_list aux lst in ok () diff --git a/src/ligo/transpiler/transpiler.ml b/src/ligo/transpiler/transpiler.ml index 2a0c63365..f95cd48b9 100644 --- a/src/ligo/transpiler/transpiler.ml +++ b/src/ligo/transpiler/transpiler.ml @@ -70,7 +70,7 @@ let tuple_access_to_lr : type_value -> type_value list -> int -> (type_value * [ let lr_path = List.map (fun b -> if b then `Right else `Left) path in let%bind (_ , lst) = let aux = fun (ty , acc) cur -> - let%bind (a , b) = get_t_pair ty in + let%bind (a , b) = Mini_c.get_t_pair ty in match cur with | `Left -> ok (a , (a , `Left) :: acc) | `Right -> ok (b , (b , `Right) :: acc) in @@ -89,10 +89,10 @@ let record_access_to_lr : type_value -> type_value AST.type_name_map -> string - let node a b : (type_value * (type_value * [`Left | `Right]) list) result = match%bind bind_lr (a, b) with | `Left (t, acc) -> - let%bind (a, _) = get_t_pair t in + let%bind (a, _) = Mini_c.get_t_pair t in ok @@ (t, (a, `Left) :: acc) | `Right (t, acc) -> ( - let%bind (_, b) = get_t_pair t in + let%bind (_, b) = Mini_c.get_t_pair t in ok @@ (t, (b, `Right) :: acc) ) in let error_content () = @@ -195,6 +195,10 @@ and transpile_environment : AST.full_environment -> Environment.t result = fun x let%bind nlst = bind_map_ne_list transpile_small_environment x in ok @@ List.Ne.to_list nlst +and tree_of_sum : AST.type_value -> (type_name * AST.type_value) Append_tree.t result = fun t -> + let%bind map_tv = get_t_sum t in + ok @@ Append_tree.of_list @@ kv_list_of_map map_tv + and translate_annotated_expression (env:Environment.t) (ae:AST.annotated_expression) : expression result = let%bind tv = translate_type ae.type_annotation in let return ?(tv = tv) expr = @@ -213,10 +217,9 @@ and translate_annotated_expression (env:Environment.t) (ae:AST.annotated_express let%bind b = translate_annotated_expression env b in return @@ E_application (a, b) | E_constructor (m, param) -> - let%bind param' = translate_annotated_expression env ae in + let%bind param' = translate_annotated_expression env param in let (param'_expr , param'_tv) = Combinators.Expression.(get_content param' , get_type param') in - let%bind map_tv = get_t_sum ae.type_annotation in - let node_tv = Append_tree.of_list @@ kv_list_of_map map_tv in + let%bind node_tv = tree_of_sum ae.type_annotation in let leaf (k, tv) : (expression' option * type_value) result = if k = m then ( let%bind _ = @@ -297,11 +300,11 @@ and translate_annotated_expression (env:Environment.t) (ae:AST.annotated_express let node (a:expression result) b : expression result = match%bind bind_lr (a, b) with | `Left expr -> ( - let%bind (tv, _) = get_t_pair @@ Combinators.Expression.get_type expr in + let%bind (tv, _) = Mini_c.get_t_pair @@ Expression.get_type expr in return ~tv @@ E_constant ("CAR", [expr]) ) | `Right expr -> ( - let%bind (_, tv) = get_t_pair @@ Combinators.Expression.get_type expr in + let%bind (_, tv) = Mini_c.get_t_pair @@ Expression.get_type expr in return ~tv @@ E_constant ("CDR", [expr]) ) in let%bind expr = @@ -341,13 +344,74 @@ and translate_annotated_expression (env:Environment.t) (ae:AST.annotated_express | E_matching (expr, m) -> ( let%bind expr' = translate_annotated_expression env expr in match m with - | AST.Match_bool {match_true ; match_false} -> - let%bind (t, f) = bind_map_pair (translate_annotated_expression env) (match_true, match_false) in + | Match_bool {match_true ; match_false} -> + let%bind (t , f) = bind_map_pair (translate_annotated_expression env) (match_true, match_false) in return @@ E_Cond (expr', t, f) - | AST.Match_list _ | AST.Match_option _ | AST.Match_tuple (_, _) -> - simple_fail "only match bool exprs are translated yet" + | Match_option { match_none; match_some = ((name, tv), s) } -> + let%bind n = translate_annotated_expression env match_none in + let%bind (tv' , s') = + let%bind tv' = translate_type tv in + let env' = Environment.(add (name , tv') @@ extend env) in + let%bind s' = translate_annotated_expression env' s in + ok (tv' , s') in + return @@ E_if_none (expr' , n , ((name , tv') , s')) + | Match_variant (lst , variant) -> ( + let%bind tree = tree_of_sum variant in + let%bind tree' = match tree with + | Empty -> simple_fail "match empty variant" + | Full x -> ok x in + let%bind tree'' = + let rec aux t = + match (t : _ Append_tree.t') with + | Leaf (name , tv) -> + let%bind tv' = translate_type tv in + ok (`Leaf name , tv') + | Node {a ; b} -> + let%bind a' = aux a in + let%bind b' = aux b in + let tv' = Mini_c.t_union (snd a') (snd b') in + ok (`Node (a' , b') , tv') + in aux tree' + in + + let rec aux acc t = + let top = + match acc with + | None -> expr' + | Some x -> x in + match t with + | ((`Leaf constructor_name) , tv) -> ( + let%bind ((_ , name) , body) = + trace_option (simple_error "not supposed to happen here: missing match clause") @@ + List.find_opt (fun ((constructor_name' , _) , _) -> constructor_name' = constructor_name) lst in + let env' = Environment.(add (name , tv) @@ extend env) in + let%bind body' = translate_annotated_expression env' body in + return @@ E_let_in ((name , tv) , top , body') + ) + | ((`Node (a , b)) , tv) -> + let%bind a' = + let%bind a_ty = get_t_left tv in + let a_var = "left" , a_ty in + let env' = Environment.(add a_var @@ extend env) in + let%bind e = aux (Some (Expression.make (E_variable "left") a_ty env')) a in + ok (a_var , e) + in + let%bind b' = + let%bind b_ty = get_t_right tv in + let b_var = "right" , b_ty in + let env' = Environment.(add b_var @@ extend env) in + let%bind e = aux (Some (Expression.make (E_variable "right") b_ty env')) b in + ok (b_var , e) + in + return @@ E_if_left (top , a' , b') + in + aux None tree'' + ) + | AST.Match_list _ | AST.Match_tuple (_, _) -> + simple_fail "only match bool and option exprs are translated yet" ) + and translate_lambda_shallow : Mini_c.Environment.t -> AST.lambda -> Mini_c.expression result = fun env l -> let { binder ; input_type ; output_type ; body ; result } : AST.lambda = l in (* Shallow capture. Capture the whole environment. Extend it with a new scope. Append it the input. *) @@ -448,8 +512,10 @@ let translate_entry (lst:AST.program) (name:string) : anon_function result = @@ aux [] lst in ok (lst', l, tv) in let l' = {l with body = lst' @ l.body} in - trace (simple_error "translating entry") - @@ translate_main l' tv + let r = + trace (simple_error "translating entry") @@ + translate_main l' tv in + r open Combinators diff --git a/src/ligo/typer/typer.ml b/src/ligo/typer/typer.ml index 1206cf8c8..290c421e1 100644 --- a/src/ligo/typer/typer.ml +++ b/src/ligo/typer/typer.ml @@ -189,6 +189,50 @@ and type_match : type i o . (environment -> i -> o result) -> environment -> O.t let e' = List.fold_left aux e lst' in let%bind b' = f e' b in ok (O.Match_tuple (lst, b')) + | Match_variant lst -> + let%bind variant_opt = + let aux acc ((constructor_name , _) , _) = + let%bind (_ , variant) = + trace_option (simple_error "bad constructor") @@ + Environment.get_constructor constructor_name e in + let%bind acc = match acc with + | None -> ok (Some variant) + | Some variant' -> ( + Ast_typed.assert_type_value_eq (variant , variant') >>? fun () -> + ok (Some variant) + ) in + ok acc in + trace (simple_error "in match variant") @@ + bind_fold_list aux None lst in + let%bind variant = + trace_option (simple_error "empty variant") @@ + variant_opt in + let%bind () = + let%bind variant_cases' = Ast_typed.Combinators.get_t_sum variant in + let variant_cases = List.map fst @@ Map.String.to_kv_list variant_cases' in + let match_cases = List.map (Function.compose fst fst) lst in + let test_case = fun c -> + Assert.assert_true (List.mem c match_cases) + in + let%bind () = + trace (simple_error "missing case match") @@ + bind_iter_list test_case variant_cases in + let%bind () = + trace_strong (simple_error "redundant case match") @@ + Assert.assert_true List.(length variant_cases = length match_cases) in + ok () + in + let%bind lst' = + let aux ((constructor_name , name) , b) = + let%bind (constructor , _) = + trace_option (simple_error "bad constructor??") @@ + Environment.get_constructor constructor_name e in + let e' = Environment.add_ez name constructor e in + let%bind b' = f e' b in + ok ((constructor_name , name) , b') + in + bind_map_list aux lst in + ok (O.Match_variant (lst' , variant)) and evaluate_type (e:environment) (t:I.type_expression) : O.type_value result = let return tv' = ok (make_t tv' (Some t)) in @@ -387,12 +431,26 @@ and type_annotated_expression : environment -> I.annotated_expression -> O.annot | E_matching (ex, m) -> ( let%bind ex' = type_annotated_expression e ex in let%bind m' = type_match type_annotated_expression e ex'.type_annotation m in - let%bind tv = match m' with - | Match_bool {match_true ; match_false} -> - let%bind _ = O.assert_type_value_eq (match_true.type_annotation, match_false.type_annotation) in - ok match_true.type_annotation - | _ -> simple_fail "can only type match_bool expressions yet" in - return (E_matching (ex' , m')) tv + let tvs = + let aux (cur:O.value O.matching) = + match cur with + | Match_bool { match_true ; match_false } -> [ match_true ; match_false ] + | Match_list { match_nil ; match_cons = (_ , _ , match_cons) } -> [ match_nil ; match_cons ] + | Match_option { match_none ; match_some = (_ , match_some) } -> [ match_none ; match_some ] + | Match_tuple (_ , match_tuple) -> [ match_tuple ] + | Match_variant (lst , _) -> List.map snd lst in + List.map get_type_annotation @@ aux m' in + let aux prec cur = + let%bind () = + match prec with + | None -> ok () + | Some cur' -> Ast_typed.assert_type_value_eq (cur , cur') in + ok (Some cur) in + let%bind tv_opt = bind_fold_list aux None tvs in + let%bind tv = + trace_option (simple_error "empty matching") @@ + tv_opt in + return (O.E_matching (ex', m')) tv ) and type_constant (name:string) (lst:O.type_value list) (tv_opt:O.type_value option) : (string * O.type_value) result = @@ -551,3 +609,9 @@ and untype_matching : type o i . (o -> i result) -> o O.matching -> (i I.matchin let%bind cons = f cons in let match_cons = hd, tl, cons in ok @@ Match_list {match_nil ; match_cons} + | Match_variant (lst , _) -> + let aux ((a,b),c) = + let%bind c' = f c in + ok ((a,b),c') in + let%bind lst' = bind_map_list aux lst in + ok @@ Match_variant lst'