diff --git a/src/passes/2-simplify/cameligo.ml b/src/passes/2-simplify/cameligo.ml index 2725bf952..9b047239d 100644 --- a/src/passes/2-simplify/cameligo.ml +++ b/src/passes/2-simplify/cameligo.ml @@ -22,7 +22,14 @@ let get_value : 'a Raw.reg -> 'a = fun x -> x.value module Errors = struct let wrong_pattern expected_name actual = let title () = "wrong pattern" in - let message () = "" in + let message () = + match actual with + | Raw.PTuple _ -> "tuple" + | Raw.PRecord _ -> "record" + | Raw.PList _ -> "list" + | Raw.PBytes _ -> "bytes" + | _ -> "other" + in let data = [ ("expected", fun () -> expected_name); ("actual_loc" , fun () -> Format.asprintf "%a" Location.pp_lift @@ Raw.pattern_to_region actual) @@ -128,6 +135,12 @@ module Errors = struct fun () -> Format.asprintf "%a" Location.pp_lift @@ region) ] in error ~data title message + + let corner_case description = + let title () = "corner case" in + let message () = description in + error title message + end open Errors @@ -160,9 +173,16 @@ let rec expr_to_typed_expr : Raw.expr -> _ = function | EAnnot {value={inside=e,_,t; _}; _} -> ok (e, Some t) | e -> ok (e , None) -let patterns_to_var : Raw.pattern nseq -> _ = fun ps -> +let rec patterns_to_typed_vars : Raw.pattern nseq -> _ = fun ps -> match ps with - | pattern, [] -> pattern_to_var pattern + | pattern, [] -> + begin + match pattern with + | Raw.PPar pp -> patterns_to_typed_vars (pp.value.inside, []) + | Raw.PTuple pt -> bind_map_list pattern_to_typed_var (npseq_to_list pt.value) + | Raw.PVar _ -> bind_list [pattern_to_typed_var pattern] + | other -> (fail @@ wrong_pattern "parenthetical, tuple, or variable" other) + end | _ -> fail @@ multiple_patterns "let" (nseq_to_list ps) let rec simpl_type_expression : Raw.type_expr -> type_expression result = fun te -> @@ -254,16 +274,51 @@ let rec simpl_expression : Raw.ELetIn e -> let Raw.{binding; body; _} = e.value in let Raw.{binders; lhs_type; let_rhs; _} = binding in - let%bind variable = patterns_to_var binders in + let%bind variables = patterns_to_typed_vars binders in let%bind ty_opt = bind_map_option (fun (_,te) -> simpl_type_expression te) lhs_type in let%bind rhs = simpl_expression let_rhs in - let rhs' = + let rhs_b = Var.fresh ~name: "rhs" () in + let rhs',rhs_b_expr = match ty_opt with - None -> rhs - | Some ty -> e_annotation rhs ty in + None -> rhs, e_variable rhs_b + | Some ty -> (e_annotation rhs ty), e_annotation (e_variable rhs_b) ty in let%bind body = simpl_expression body in - return @@ e_let_in (Var.of_name variable.value , None) rhs' body + let prepare_variable (ty_var: Raw.variable * Raw.type_expr option) = + let variable, ty_opt = ty_var in + let var_expr = Var.of_name variable.value in + let%bind ty_expr_opt = + match ty_opt with + | Some ty -> bind_map_option simpl_type_expression (Some ty) + | None -> ok None + in ok (var_expr, ty_expr_opt) + in + let%bind prep_vars = bind_list (List.map prepare_variable variables) in + let%bind () = + if (List.length prep_vars) = 0 + then fail @@ corner_case "let ... in without variables passed parsing stage" + else ok () + in + let rhs_b_expr = (* We only want to evaluate the rhs first if multi-bind *) + if List.length prep_vars = 1 + then rhs' else rhs_b_expr + in + let rec chain_let_in variables body : expression = + match variables with + | hd :: [] -> + if (List.length prep_vars = 1) + then e_let_in hd rhs_b_expr body + else e_let_in hd (e_accessor rhs_b_expr [Access_tuple ((List.length prep_vars) - 1)]) body + | hd :: tl -> + e_let_in hd + (e_accessor rhs_b_expr [Access_tuple ((List.length prep_vars) - (List.length tl) - 1)]) + (chain_let_in tl body) + | [] -> body (* Precluded by corner case assertion above *) + in + if List.length prep_vars = 1 + then ok (chain_let_in prep_vars body) + (* Bind the right hand side so we only evaluate it once *) + else ok (e_let_in (rhs_b, ty_opt) rhs' (chain_let_in prep_vars body)) | Raw.EAnnot a -> let Raw.{inside=expr, _, type_expr; _}, loc = r_split a in let%bind expr' = simpl_expression expr in diff --git a/src/test/contracts/let_in_multi_bind.mligo b/src/test/contracts/let_in_multi_bind.mligo new file mode 100644 index 000000000..e61dc14a7 --- /dev/null +++ b/src/test/contracts/let_in_multi_bind.mligo @@ -0,0 +1,5 @@ +let sum (p: int * int) : int = + let i, result = p in i + result + +let sum2 (p: string * string * string * string) : int = + let a, b, c, d = p in a ^ b ^ c ^ d diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index 05a95463d..f6b58f237 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -1795,6 +1795,18 @@ let type_tuple_destruct () : unit result = let%bind () = expect_eq program "type_tuple_d_2" (e_unit ()) (e_string "helloworld") in ok () +let let_in_multi_bind () : unit result = + let%bind program = mtype_file "./contracts/let_in_multi_bind.mligo" in + let%bind () = expect_eq program "sum" (e_tuple [e_int 10; e_int 10]) (e_int 20) in + let%bind () = expect_eq program "sum2" + (e_tuple + [e_string "my" ; + e_string "name" ; + e_string "is" ; + e_string "bob" ]) + (e_string "mynameisbob") + in ok () + let main = test_suite "Integration (End to End)" [ test "key hash" key_hash ; test "chain id" chain_id ; @@ -1933,4 +1945,5 @@ let main = test_suite "Integration (End to End)" [ test "deep_access (ligo)" deep_access_ligo; test "entrypoints (ligo)" entrypoints_ligo ; test "type tuple destruct (mligo)" type_tuple_destruct ; + test "let in multi-bind (mligo)" let_in_multi_bind ; ]