diff --git a/src/passes/6-transpiler/transpiler.ml b/src/passes/6-transpiler/transpiler.ml index 11ff10988..fc71afebb 100644 --- a/src/passes/6-transpiler/transpiler.ml +++ b/src/passes/6-transpiler/transpiler.ml @@ -407,6 +407,7 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re | ("LIST_MAP" , lst) -> map lst | ("MAP_MAP" , lst) -> map lst | ("LIST_FOLD" , lst) -> fold lst + | ("SET_FOLD" , lst) -> fold lst | ("MAP_FOLD" , lst) -> fold lst | _ -> ( let%bind lst' = bind_map_list (transpile_annotated_expression) lst in diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 9ead3b7bd..2dc5ef7d6 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -81,6 +81,7 @@ module Simplify = struct ("set_add" , "SET_ADD") ; ("set_remove" , "SET_REMOVE") ; ("set_iter" , "SET_ITER") ; + ("set_fold" , "SET_FOLD") ; ("list_iter" , "LIST_ITER") ; ("list_fold" , "LIST_FOLD") ; ("list_map" , "LIST_MAP") ; @@ -148,6 +149,7 @@ module Simplify = struct ("Set.empty" , "SET_EMPTY") ; ("Set.add" , "SET_ADD") ; ("Set.remove" , "SET_REMOVE") ; + ("Set.fold" , "SET_FOLD") ; ("Map.find_opt" , "MAP_FIND_OPT") ; ("Map.find" , "MAP_FIND") ; @@ -156,7 +158,7 @@ module Simplify = struct ("Map.remove" , "MAP_REMOVE") ; ("Map.iter" , "MAP_ITER") ; ("Map.map" , "MAP_MAP") ; - ("Map.fold" , "LIST_FOLD") ; + ("Map.fold" , "MAP_FOLD") ; ("String.length", "SIZE") ; ("String.size", "SIZE") ; @@ -496,6 +498,20 @@ module Typer = struct let%bind () = assert_eq_1 ~msg:"res init" res init in ok res + let set_fold = typer_3 "SET_FOLD" @@ fun lst init body -> + let%bind (arg , res) = get_t_function body in + let%bind (prec , cur) = get_t_pair arg in + let%bind key = get_t_set lst in + let msg = Format.asprintf "%a vs %a" + Ast_typed.PP.type_value key + Ast_typed.PP.type_value arg + in + trace (simple_error ("bad set fold:" ^ msg)) @@ + let%bind () = assert_eq_1 ~msg:"key cur" key cur in + let%bind () = assert_eq_1 ~msg:"prec res" prec res in + let%bind () = assert_eq_1 ~msg:"res init" res init in + ok res + let map_fold = typer_3 "MAP_FOLD" @@ fun map init body -> let%bind (arg , res) = get_t_function body in let%bind (prec , cur) = get_t_pair arg in @@ -593,6 +609,7 @@ module Typer = struct set_add ; set_remove ; set_iter ; + set_fold ; list_iter ; list_map ; list_fold ; diff --git a/src/test/contracts/set_arithmetic-1.ligo b/src/test/contracts/set_arithmetic-1.ligo index 0cfab61d2..f5d332687 100644 --- a/src/test/contracts/set_arithmetic-1.ligo +++ b/src/test/contracts/set_arithmetic-1.ligo @@ -9,3 +9,8 @@ function iter_op (const s : set(int)) : int is begin set_iter(s , aggregate) ; end with r + +function fold_op (const s : set(int)) : int is + function aggregate (const i : int ; const j : int) : int is + block { skip } with i + j + block { skip } with set_fold(s , 15 , aggregate) diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index eab8395da..85e02d22d 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -224,6 +224,11 @@ let set_arithmetic () : unit result = expect_eq program "mem_op" (e_set [e_string "foo" ; e_string "bar"]) (e_bool false) in + let%bind () = + expect_eq program_1 "fold_op" + (e_set [ e_int 4 ; e_int 10 ]) + (e_int 29) + in ok () let unit_expression () : unit result =