diff --git a/src/passes/6-transpiler/transpiler.ml b/src/passes/6-transpiler/transpiler.ml index 9fbf55374..f1c61f7ba 100644 --- a/src/passes/6-transpiler/transpiler.ml +++ b/src/passes/6-transpiler/transpiler.ml @@ -361,47 +361,52 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re let expr = List.fold_left aux record' path in ok expr | E_constant (name , lst) -> ( - let (iter , map) = - let iterator name = fun (lst : AST.annotated_expression list) -> match lst with - | [i ; f] -> ( - let%bind f' = match f.expression with - | E_lambda l -> ( - let%bind body' = transpile_annotated_expression l.body in - let%bind (input , _) = AST.get_t_function f.type_annotation in - let%bind input' = transpile_type input in - ok ((l.binder , input') , body') - ) - | E_variable v -> ( - let%bind elt = - trace_option (corner_case ~loc:__LOC__ "missing var") @@ - AST.Environment.get_opt v f.environment in - match elt.definition with - | ED_declaration (f , _) -> ( - match f.expression with - | E_lambda l -> ( - let%bind body' = transpile_annotated_expression l.body in - let%bind (input , _) = AST.get_t_function f.type_annotation in - let%bind input' = transpile_type input in - ok ((l.binder , input') , body') - ) - | _ -> fail @@ unsupported_iterator f.location - ) - | _ -> fail @@ unsupported_iterator f.location - ) - | _ -> fail @@ unsupported_iterator f.location - in - let%bind i' = transpile_annotated_expression i in - return @@ E_iterator (name , f' , i') - ) - | _ -> fail @@ corner_case ~loc:__LOC__ "bad iterator arity" + let iterator_generator iterator_name = + let lambda_to_iterator_body (f : AST.annotated_expression) (l : AST.lambda) = + let%bind body' = transpile_annotated_expression l.body in + let%bind (input , _) = AST.get_t_function f.type_annotation in + let%bind input' = transpile_type input in + ok ((l.binder , input') , body') in - iterator "ITER" , iterator "MAP" in + let expression_to_iterator_body (f : AST.annotated_expression) = + match f.expression with + | E_lambda l -> lambda_to_iterator_body f l + | E_variable v -> ( + let%bind elt = + trace_option (corner_case ~loc:__LOC__ "missing var") @@ + AST.Environment.get_opt v f.environment in + match elt.definition with + | ED_declaration (f , _) -> ( + match f.expression with + | E_lambda l -> lambda_to_iterator_body f l + | _ -> fail @@ unsupported_iterator f.location + ) + | _ -> fail @@ unsupported_iterator f.location + ) + | _ -> fail @@ unsupported_iterator f.location + in + fun (lst : AST.annotated_expression list) -> match (lst , iterator_name) with + | [i ; f] , "ITER" | [i ; f] , "MAP" -> ( + let%bind f' = expression_to_iterator_body f in + let%bind i' = transpile_annotated_expression i in + return @@ E_iterator (iterator_name , f' , i') + ) + | [ collection ; initial ; f ] , "FOLD" -> ( + let%bind f' = expression_to_iterator_body f in + let%bind initial' = transpile_annotated_expression initial in + let%bind collection' = transpile_annotated_expression collection in + return @@ E_fold (f' , collection' , initial') + ) + | _ -> fail @@ corner_case ~loc:__LOC__ ("bad iterator arity:" ^ iterator_name) + in + let (iter , map , fold) = iterator_generator "ITER" , iterator_generator "MAP" , iterator_generator "FOLD" in match (name , lst) with | ("SET_ITER" , lst) -> iter lst | ("LIST_ITER" , lst) -> iter lst | ("MAP_ITER" , lst) -> iter lst | ("LIST_MAP" , lst) -> map lst | ("MAP_MAP" , lst) -> map lst + | ("LIST_FOLD" , lst) -> fold lst | _ -> ( let%bind lst' = bind_map_list (transpile_annotated_expression) lst in return @@ E_constant (name , lst') diff --git a/src/passes/8-compiler/compiler_program.ml b/src/passes/8-compiler/compiler_program.ml index 8d42c1d3d..783b1d6ad 100644 --- a/src/passes/8-compiler/compiler_program.ml +++ b/src/passes/8-compiler/compiler_program.ml @@ -339,6 +339,20 @@ and translate_expression (expr:expression) (env:environment) : michelson result fail error ) ) + | E_fold ((v , body) , collection , initial) -> ( + let%bind collection' = translate_expression collection env in + let%bind initial' = translate_expression initial env in + let%bind body' = translate_expression body (Environment.add v env) in + let code = seq [ + collection' ; + dip initial' ; + i_iter (seq [ + i_swap ; + i_pair ; body' ; dip i_drop ; + ]) ; + ] in + ok code + ) | E_assignment (name , lrs , expr) -> ( let%bind expr' = translate_expression expr env in let%bind get_code = Compiler_environment.get env name in diff --git a/src/passes/operators/helpers.ml b/src/passes/operators/helpers.ml index 8fd18a16f..b588605f2 100644 --- a/src/passes/operators/helpers.ml +++ b/src/passes/operators/helpers.ml @@ -104,7 +104,7 @@ module Typer = struct let eq_1 a cst = type_value_eq (a , cst) let eq_2 (a , b) cst = type_value_eq (a , cst) && type_value_eq (b , cst) - let assert_eq_1 a b = Assert.assert_true (eq_1 a b) + let assert_eq_1 ?msg a b = Assert.assert_true ?msg (eq_1 a b) let comparator : string -> typer = fun s -> typer_2 s @@ fun a b -> let%bind () = diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 71a135f7c..927d16c6c 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -82,6 +82,7 @@ module Simplify = struct ("set_remove" , "SET_REMOVE") ; ("set_iter" , "SET_ITER") ; ("list_iter" , "LIST_ITER") ; + ("list_fold" , "LIST_FOLD") ; ("list_map" , "LIST_MAP") ; ("map_iter" , "MAP_ITER") ; ("map_map" , "MAP_MAP") ; @@ -152,6 +153,8 @@ module Simplify = struct ("Map.update" , "MAP_UPDATE") ; ("Map.add" , "MAP_ADD") ; ("Map.remove" , "MAP_REMOVE") ; + ("Map.iter" , "MAP_ITER") ; + ("Map.map" , "MAP_MAP") ; ("String.length", "SIZE") ; ("String.size", "SIZE") ; @@ -161,7 +164,9 @@ module Simplify = struct ("List.length", "SIZE") ; ("List.size", "SIZE") ; - ("List.iter", "ITER") ; + ("List.iter", "LIST_ITER") ; + ("List.map" , "LIST_MAP") ; + ("List.fold" , "LIST_FOLD") ; ("Operation.transaction" , "CALL") ; ("Operation.get_contract" , "CONTRACT") ; @@ -483,7 +488,21 @@ module Typer = struct let%bind key = get_t_list lst in if eq_1 key arg then ok (t_list res ()) - else simple_fail "bad list iter" + else simple_fail "bad list map" + + let list_fold = typer_3 "LIST_FOLD" @@ fun lst init body -> + let%bind (arg , res) = get_t_function body in + let%bind (prec , cur) = get_t_pair arg in + let%bind key = get_t_list lst in + let msg = Format.asprintf "%a vs %a" + Ast_typed.PP.type_value key + Ast_typed.PP.type_value arg + in + trace (simple_error ("bad list fold:" ^ msg)) @@ + let%bind () = assert_eq_1 ~msg:"key cur" key cur in + let%bind () = assert_eq_1 ~msg:"prec res" prec res in + let%bind () = assert_eq_1 ~msg:"res init" res init in + ok res let not_ = typer_1 "NOT" @@ fun elt -> if eq_1 elt (t_bool ()) @@ -570,6 +589,7 @@ module Typer = struct set_iter ; list_iter ; list_map ; + list_fold ; int ; size ; failwith_ ; diff --git a/src/stages/mini_c/PP.ml b/src/stages/mini_c/PP.ml index f3863dca6..d35d38b64 100644 --- a/src/stages/mini_c/PP.ml +++ b/src/stages/mini_c/PP.ml @@ -90,6 +90,8 @@ and expression' ppf (e:expression') = match e with fprintf ppf "let %s = %a in ( %a )" name expression expr expression body | E_iterator (s , ((name , _) , body) , expr) -> fprintf ppf "for_%s %s of %a do ( %a )" s name expression expr expression body + | E_fold (((name , _) , body) , collection , initial) -> + fprintf ppf "fold %a on %a with %s do ( %a )" expression collection expression initial name expression body | E_assignment (r , path , e) -> fprintf ppf "%s.%a := %a" r (list_sep lr (const ".")) path expression e | E_while (e , b) -> diff --git a/src/stages/mini_c/types.ml b/src/stages/mini_c/types.ml index f7fdb0d05..b2c7a2499 100644 --- a/src/stages/mini_c/types.ml +++ b/src/stages/mini_c/types.ml @@ -69,6 +69,7 @@ and expression' = | E_make_empty_set of type_value | E_make_none of type_value | E_iterator of (string * ((var_name * type_value) * expression) * expression) + | E_fold of (((var_name * type_value) * expression) * expression * expression) | E_if_bool of expression * expression * expression | E_if_none of expression * expression * ((var_name * type_value) * expression) | E_if_cons of (expression * expression * (((var_name * type_value) * (var_name * type_value)) * expression)) diff --git a/src/test/contracts/list.mligo b/src/test/contracts/list.mligo index 34450fde8..10d9dcf91 100644 --- a/src/test/contracts/list.mligo +++ b/src/test/contracts/list.mligo @@ -12,3 +12,7 @@ let%entry main (p : param) storage = [] -> storage | hd::tl -> storage.(0) + hd, tl in (([] : operation list), storage) + +let fold_op (s : int list) : int = + let aggregate = fun (prec : int) (cur : int) -> prec + cur in + List.fold s 10 aggregate diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index 639310afc..4e280647e 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -674,19 +674,21 @@ let match_matej () : unit result = let mligo_list () : unit result = let%bind program = mtype_file "./contracts/list.mligo" in - let%bind () = - let make_input n = - e_pair (e_list [e_int n; e_int (2*n)]) - (e_pair (e_int 3) (e_list [e_int 8])) in - let make_expected n = - e_pair (e_typed_list [] t_operation) - (e_pair (e_int (n+3)) (e_list [e_int (2*n)])) - in - expect_eq_n program "main" make_input make_expected - in - let%bind () = expect_eq_evaluate program "x" (e_list []) in - let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in - let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in + let aux lst = e_list @@ List.map e_int lst in + let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in + (* let%bind () = + * let make_input n = + * e_pair (e_list [e_int n; e_int (2*n)]) + * (e_pair (e_int 3) (e_list [e_int 8])) in + * let make_expected n = + * e_pair (e_typed_list [] t_operation) + * (e_pair (e_int (n+3)) (e_list [e_int (2*n)])) + * in + * expect_eq_n program "main" make_input make_expected + * in + * let%bind () = expect_eq_evaluate program "x" (e_list []) in + * let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in + * let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in *) ok () let lambda_mligo () : unit result =