diff --git a/src/passes/6-transpiler/transpiler.ml b/src/passes/6-transpiler/transpiler.ml index f1c61f7ba..11ff10988 100644 --- a/src/passes/6-transpiler/transpiler.ml +++ b/src/passes/6-transpiler/transpiler.ml @@ -407,6 +407,7 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re | ("LIST_MAP" , lst) -> map lst | ("MAP_MAP" , lst) -> map lst | ("LIST_FOLD" , lst) -> fold lst + | ("MAP_FOLD" , lst) -> fold lst | _ -> ( let%bind lst' = bind_map_list (transpile_annotated_expression) lst in return @@ E_constant (name , lst') diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 927d16c6c..9ead3b7bd 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -86,6 +86,7 @@ module Simplify = struct ("list_map" , "LIST_MAP") ; ("map_iter" , "MAP_ITER") ; ("map_map" , "MAP_MAP") ; + ("map_fold" , "MAP_FOLD") ; ("sha_256" , "SHA256") ; ("sha_512" , "SHA512") ; ("blake2b" , "BLAKE2b") ; @@ -155,6 +156,7 @@ module Simplify = struct ("Map.remove" , "MAP_REMOVE") ; ("Map.iter" , "MAP_ITER") ; ("Map.map" , "MAP_MAP") ; + ("Map.fold" , "LIST_FOLD") ; ("String.length", "SIZE") ; ("String.size", "SIZE") ; @@ -285,16 +287,6 @@ module Typer = struct let%bind () = assert_eq_1 arg (t_pair k v ()) in ok @@ t_map k res () - let map_fold : typer = typer_2 "MAP_FOLD" @@ fun f m -> - let%bind (k, v) = get_t_map m in - let%bind (arg_1 , res) = get_t_function f in - let%bind (arg_2 , res') = get_t_function res in - let%bind (arg_3 , res'') = get_t_function res' in - let%bind () = assert_eq_1 arg_1 k in - let%bind () = assert_eq_1 arg_2 v in - let%bind () = assert_eq_1 arg_3 res'' in - ok @@ res' - let size = typer_1 "SIZE" @@ fun t -> let%bind () = Assert.assert_true @@ @@ -504,6 +496,20 @@ module Typer = struct let%bind () = assert_eq_1 ~msg:"res init" res init in ok res + let map_fold = typer_3 "MAP_FOLD" @@ fun map init body -> + let%bind (arg , res) = get_t_function body in + let%bind (prec , cur) = get_t_pair arg in + let%bind (key , value) = get_t_map map in + let msg = Format.asprintf "%a vs %a" + Ast_typed.PP.type_value key + Ast_typed.PP.type_value arg + in + trace (simple_error ("bad list fold:" ^ msg)) @@ + let%bind () = assert_eq_1 ~msg:"key cur" (t_pair key value ()) cur in + let%bind () = assert_eq_1 ~msg:"prec res" prec res in + let%bind () = assert_eq_1 ~msg:"res init" res init in + ok res + let not_ = typer_1 "NOT" @@ fun elt -> if eq_1 elt (t_bool ()) then ok @@ t_bool () diff --git a/src/test/contracts/map.ligo b/src/test/contracts/map.ligo index af3697768..722412603 100644 --- a/src/test/contracts/map.ligo +++ b/src/test/contracts/map.ligo @@ -44,3 +44,7 @@ function iter_op (const m : foobar) : int is function map_op (const m : foobar) : foobar is function increment (const i : int ; const j : int) : int is block { skip } with j + 1 ; block { skip } with map_map(m , increment) ; + +function fold_op (const m : foobar) : int is + function aggregate (const i : int ; const j : (int * int)) : int is block { skip } with i + j.0 + j.1 ; + block { skip } with map_fold(m , 10 , aggregate) diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index 4e280647e..eab8395da 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -400,6 +400,11 @@ let map () : unit result = let expected = e_int 66 in expect_eq program "iter_op" input expected in + let%bind () = + let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in + let expected = e_int 76 in + expect_eq program "fold_op" input expected + in let%bind () = let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in let expected = ez [(1 , 11) ; (2 , 21) ; (3 , 31) ] in @@ -676,19 +681,19 @@ let mligo_list () : unit result = let%bind program = mtype_file "./contracts/list.mligo" in let aux lst = e_list @@ List.map e_int lst in let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in - (* let%bind () = - * let make_input n = - * e_pair (e_list [e_int n; e_int (2*n)]) - * (e_pair (e_int 3) (e_list [e_int 8])) in - * let make_expected n = - * e_pair (e_typed_list [] t_operation) - * (e_pair (e_int (n+3)) (e_list [e_int (2*n)])) - * in - * expect_eq_n program "main" make_input make_expected - * in - * let%bind () = expect_eq_evaluate program "x" (e_list []) in - * let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in - * let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in *) + let%bind () = + let make_input n = + e_pair (e_list [e_int n; e_int (2*n)]) + (e_pair (e_int 3) (e_list [e_int 8])) in + let make_expected n = + e_pair (e_typed_list [] t_operation) + (e_pair (e_int (n+3)) (e_list [e_int (2*n)])) + in + expect_eq_n program "main" make_input make_expected + in + let%bind () = expect_eq_evaluate program "x" (e_list []) in + let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in + let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in ok () let lambda_mligo () : unit result =