diff --git a/gitlab-pages/docs/language-basics/sets-lists-touples.md b/gitlab-pages/docs/language-basics/sets-lists-touples.md index ec1bfa43f..707887eb1 100644 --- a/gitlab-pages/docs/language-basics/sets-lists-touples.md +++ b/gitlab-pages/docs/language-basics/sets-lists-touples.md @@ -102,13 +102,13 @@ let smaller_set: int_set = Set.remove 3 my_set ```pascaligo function sum(const result: int; const i: int): int is result + i; // Outputs 6 -const sum_of_a_set: int = set_fold(my_set, 0, sum); +const sum_of_a_set: int = set_fold(sum, my_set, 0); ``` ```cameligo let sum (result: int) (i: int) : int = result + i -let sum_of_a_set: int = Set.fold my_set 0 sum +let sum_of_a_set: int = Set.fold sum my_set 0 ``` @@ -168,7 +168,7 @@ let larger_list: int_list = 4 :: my_list ```pascaligo function increment(const i: int): int is block { skip } with i + 1; // Creates a new list with elements incremented by 1 -const incremented_list: int_list = list_map(even_larger_list, increment); +const incremented_list: int_list = list_map(increment, even_larger_list); ``` @@ -176,7 +176,7 @@ const incremented_list: int_list = list_map(even_larger_list, increment); ```cameligo let increment (i: int) : int = i + 1 (* Creates a new list with elements incremented by 1 *) -let incremented_list: int_list = List.map larger_list increment +let incremented_list: int_list = List.map increment larger_list ``` @@ -188,7 +188,7 @@ let incremented_list: int_list = List.map larger_list increment ```pascaligo function sum(const result: int; const i: int): int is block { skip } with result + i; // Outputs 6 -const sum_of_a_list: int = list_fold(my_list, 0, sum); +const sum_of_a_list: int = list_fold(sum, my_list, 0); ``` @@ -196,7 +196,7 @@ const sum_of_a_list: int = list_fold(my_list, 0, sum); ```cameligo let sum (result: int) (i: int) : int = result + i // Outputs 6 -let sum_of_a_list: int = List.fold my_list 0 sum +let sum_of_a_list: int = List.fold sum my_list 0 ``` @@ -237,4 +237,4 @@ const first_name: string = full_name.1; let first_name: string = full_name.1 ``` - \ No newline at end of file + diff --git a/src/passes/2-simplify/pascaligo.ml b/src/passes/2-simplify/pascaligo.ml index a78b11cd7..63f9d8e70 100644 --- a/src/passes/2-simplify/pascaligo.ml +++ b/src/passes/2-simplify/pascaligo.ml @@ -1205,7 +1205,7 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun let lambda = e_lambda "arguments" None None for_body in let op_name = match fc.collection with | Map _ -> "MAP_FOLD" | Set _ -> "SET_FOLD" | List _ -> "LIST_FOLD" in - let fold = e_constant op_name [collect ; init_record ; lambda] in + let fold = e_constant op_name [lambda; collect ; init_record] in (* STEP 8 *) let assign_back (prev : expression option) (captured_varname : string) : expression option = let access = e_accessor (e_variable "#COMPILER#folded_record") diff --git a/src/passes/4-typer-old/typer.ml b/src/passes/4-typer-old/typer.ml index daaa684bf..e2b3aaecd 100644 --- a/src/passes/4-typer-old/typer.ml +++ b/src/passes/4-typer-old/typer.ml @@ -617,13 +617,14 @@ and type_expression' : environment -> ?tv_opt:O.type_value -> I.expression -> O. return (E_lambda {binder = fst binder ; body}) (t_function input_type output_type ()) ) | E_constant ( ("LIST_FOLD"|"MAP_FOLD"|"SET_FOLD") as opname , - [ collect ; - init_record ; + [ ( { expression = (I.E_lambda { binder = (lname, None) ; input_type = None ; output_type = None ; result }) ; - location = _ }) as _lambda + location = _ }) as _lambda ; + collect ; + init_record ; ] ) -> (* this special case is here force annotation of the untyped lambda generated by pascaligo's for_collect loop *) @@ -641,7 +642,7 @@ and type_expression' : environment -> ?tv_opt:O.type_value -> I.expression -> O. let%bind body = type_expression' ?tv_opt:(Some tv_out) e' result in let output_type = body.type_annotation in let lambda' = make_a_e (E_lambda {binder = lname ; body}) (t_function input_type output_type ()) e in - let lst' = [v_col; v_initr ; lambda'] in + let lst' = [lambda'; v_col; v_initr] in let tv_lst = List.map get_type_annotation lst' in let%bind (opname', tv) = type_constant opname tv_lst tv_opt ae.location in diff --git a/src/passes/6-transpiler/transpiler.ml b/src/passes/6-transpiler/transpiler.ml index 86bb79f3c..9fa3499da 100644 --- a/src/passes/6-transpiler/transpiler.ml +++ b/src/passes/6-transpiler/transpiler.ml @@ -390,12 +390,12 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re | _ -> fail @@ unsupported_iterator f.location in fun (lst : AST.annotated_expression list) -> match (lst , iterator_name) with - | [i ; f] , "ITER" | [i ; f] , "MAP" -> ( + | [f ; i] , "ITER" | [f ; i] , "MAP" -> ( let%bind f' = expression_to_iterator_body f in let%bind i' = transpile_annotated_expression i in return @@ E_iterator (iterator_name , f' , i') ) - | [ collection ; initial ; f ] , "FOLD" -> ( + | [ f ; collection ; initial ] , "FOLD" -> ( let%bind f' = expression_to_iterator_body f in let%bind initial' = transpile_annotated_expression initial in let%bind collection' = transpile_annotated_expression collection in diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index fed495a6e..256c23bf3 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -366,14 +366,14 @@ module Typer = struct let%bind () = assert_type_value_eq (src, k) in ok @@ t_option dst () - let map_iter : typer = typer_2 "MAP_ITER" @@ fun m f -> + let map_iter : typer = typer_2 "MAP_ITER" @@ fun f m -> let%bind (k, v) = get_t_map m in let%bind (arg , res) = get_t_function f in let%bind () = assert_eq_1 arg (t_pair k v ()) in let%bind () = assert_eq_1 res (t_unit ()) in ok @@ t_unit () - let map_map : typer = typer_2 "MAP_MAP" @@ fun m f -> + let map_map : typer = typer_2 "MAP_MAP" @@ fun f m -> let%bind (k, v) = get_t_map m in let%bind (arg , res) = get_t_function f in let%bind () = assert_eq_1 arg (t_pair k v ()) in @@ -578,7 +578,7 @@ module Typer = struct then ok set else simple_fail "Set_remove: elt and set don't match" - let set_iter = typer_2 "SET_ITER" @@ fun set body -> + let set_iter = typer_2 "SET_ITER" @@ fun body set -> let%bind (arg , res) = get_t_function body in let%bind () = Assert.assert_true (eq_1 res (t_unit ())) in let%bind key = get_t_set set in @@ -586,7 +586,7 @@ module Typer = struct then ok (t_unit ()) else simple_fail "bad set iter" - let list_iter = typer_2 "LIST_ITER" @@ fun lst body -> + let list_iter = typer_2 "LIST_ITER" @@ fun body lst -> let%bind (arg , res) = get_t_function body in let%bind () = Assert.assert_true (eq_1 res (t_unit ())) in let%bind key = get_t_list lst in @@ -594,14 +594,14 @@ module Typer = struct then ok (t_unit ()) else simple_fail "bad list iter" - let list_map = typer_2 "LIST_MAP" @@ fun lst body -> + let list_map = typer_2 "LIST_MAP" @@ fun body lst -> let%bind (arg , res) = get_t_function body in let%bind key = get_t_list lst in if eq_1 key arg then ok (t_list res ()) else simple_fail "bad list map" - let list_fold = typer_3 "LIST_FOLD" @@ fun lst init body -> + let list_fold = typer_3 "LIST_FOLD" @@ fun body lst init -> let%bind (arg , res) = get_t_function body in let%bind (prec , cur) = get_t_pair arg in let%bind key = get_t_list lst in @@ -615,7 +615,7 @@ module Typer = struct let%bind () = assert_eq_1 ~msg:"res init" res init in ok res - let set_fold = typer_3 "SET_FOLD" @@ fun lst init body -> + let set_fold = typer_3 "SET_FOLD" @@ fun body lst init -> let%bind (arg , res) = get_t_function body in let%bind (prec , cur) = get_t_pair arg in let%bind key = get_t_set lst in @@ -629,7 +629,7 @@ module Typer = struct let%bind () = assert_eq_1 ~msg:"res init" res init in ok res - let map_fold = typer_3 "MAP_FOLD" @@ fun map init body -> + let map_fold = typer_3 "MAP_FOLD" @@ fun body map init -> let%bind (arg , res) = get_t_function body in let%bind (prec , cur) = get_t_pair arg in let%bind (key , value) = get_t_map map in @@ -649,7 +649,7 @@ module Typer = struct whether the fold should continue or not. Necessarily then the initial value must match the input parameter of the auxillary function, and the auxillary should return type (bool * input) *) - let fold_while = typer_2 "FOLD_WHILE" @@ fun init body -> + let fold_while = typer_2 "FOLD_WHILE" @@ fun body init -> let%bind (arg, result) = get_t_function body in let%bind () = assert_eq_1 arg init in let%bind () = assert_eq_1 (t_pair (t_bool ()) init ()) result @@ -831,7 +831,7 @@ module Compiler = struct ("MAP_FIND_OPT" , simple_binary @@ prim I_GET) ; ("MAP_ADD" , simple_ternary @@ seq [dip (i_some) ; prim I_UPDATE]) ; ("MAP_UPDATE" , simple_ternary @@ prim I_UPDATE) ; - ("FOLD_WHILE" , simple_binary @@ seq [(i_push (prim T_bool) (prim D_True)) ; + ("FOLD_WHILE" , simple_binary @@ seq [i_swap ; (i_push (prim T_bool) (prim D_True)) ; prim ~children:[seq [dip i_dup; i_exec; i_unpair]] I_LOOP ; i_swap ; i_drop]) ; ("CONTINUE" , simple_unary @@ seq [(i_push (prim T_bool) (prim D_True)) ; diff --git a/src/test/contracts/list.ligo b/src/test/contracts/list.ligo index af863fb67..77f8beec3 100644 --- a/src/test/contracts/list.ligo +++ b/src/test/contracts/list.ligo @@ -24,6 +24,15 @@ const bl : foobar = list 421 ; end +function fold_op (const s: list(int)) : int is + begin + function aggregate (const prec: int; const cur: int) : int is + begin + skip + end with prec + cur + end with list_fold(aggregate, s, 10) + + function iter_op (const s : list(int)) : int is begin var r : int := 0 ; @@ -31,10 +40,10 @@ function iter_op (const s : list(int)) : int is begin r := r + i ; end with unit ; - list_iter(s , aggregate) ; + list_iter(aggregate, s) ; end with r function map_op (const s : list(int)) : list(int) is block { function increment (const i : int) : int is block { skip } with i + 1 - } with list_map(s , increment) + } with list_map(increment, s) diff --git a/src/test/contracts/list.mligo b/src/test/contracts/list.mligo index 771ac7d9a..99829fd31 100644 --- a/src/test/contracts/list.mligo +++ b/src/test/contracts/list.mligo @@ -15,11 +15,11 @@ let main (p: param) storage = let fold_op (s: int list) : int = let aggregate = fun (prec: int) (cur: int) -> prec + cur - in List.fold s 10 aggregate + in List.fold aggregate s 10 let map_op (s: int list) : int list = - List.map s (fun (cur: int) -> cur + 1) + List.map (fun (cur: int) -> cur + 1) s let iter_op (s : int list) : unit = let do_nothing = fun (_: int) -> unit - in List.iter s do_nothing + in List.iter do_nothing s diff --git a/src/test/contracts/loop.mligo b/src/test/contracts/loop.mligo index 12d1f8f5f..64ec039f5 100644 --- a/src/test/contracts/loop.mligo +++ b/src/test/contracts/loop.mligo @@ -4,7 +4,7 @@ let aux_simple (i: int) : bool * int = if i < 100 then continue (i + 1) else stop i let counter_simple (n: int) : int = - Loop.fold_while n aux_simple + Loop.fold_while aux_simple n type sum_aggregator = { counter : int ; @@ -13,21 +13,21 @@ type sum_aggregator = { let counter (n : int) : int = let initial : sum_aggregator = { counter = 0 ; sum = 0 } in - let out : sum_aggregator = Loop.fold_while initial (fun (prev: sum_aggregator) -> + let out : sum_aggregator = Loop.fold_while (fun (prev: sum_aggregator) -> if prev.counter <= n then continue ({ counter = prev.counter + 1 ; sum = prev.counter + prev.sum }) else stop ({ counter = prev.counter ; sum = prev.sum }) - ) in out.sum + ) initial in out.sum let aux_nest (prev: sum_aggregator) : bool * sum_aggregator = if prev.counter < 100 then continue ({ counter = prev.counter + 1 ; - sum = prev.sum + Loop.fold_while prev.counter aux_simple}) + sum = prev.sum + Loop.fold_while aux_simple prev.counter}) else stop ({ counter = prev.counter ; sum = prev.sum }) let counter_nest (n: int) : int = let initial : sum_aggregator = { counter = 0 ; sum = 0 } in - let out : sum_aggregator = Loop.fold_while initial aux_nest + let out : sum_aggregator = Loop.fold_while aux_nest initial in out.sum diff --git a/src/test/contracts/map.ligo b/src/test/contracts/map.ligo index aad2d3921..7f53005cd 100644 --- a/src/test/contracts/map.ligo +++ b/src/test/contracts/map.ligo @@ -52,17 +52,17 @@ function iter_op (const m : foobar) : unit is function aggregate (const i : int ; const j : int) : unit is block { if (i=j) then skip else failwith("fail") } with unit ; // map_iter(m , aggregate) ; - } with map_iter(m, aggregate) ; + } with map_iter(aggregate, m) ; function map_op (const m : foobar) : foobar is block { function increment (const i : int ; const j : int) : int is block { skip } with j + 1 ; - } with map_map(m , increment) ; + } with map_map(increment, m) ; function fold_op (const m : foobar) : int is block { function aggregate (const i : int ; const j : (int * int)) : int is block { skip } with i + j.0 + j.1 ; - } with map_fold(m , 10 , aggregate) + } with map_fold(aggregate, m , 10) function deep_op (var m : foobar) : foobar is block { diff --git a/src/test/contracts/map.mligo b/src/test/contracts/map.mligo index 1d3038f84..85592acb1 100644 --- a/src/test/contracts/map.mligo +++ b/src/test/contracts/map.mligo @@ -30,15 +30,15 @@ let get_ (m: foobar) : int option = Map.find_opt 42 m let iter_op (m : foobar) : unit = let assert_eq = fun (i: int) (j: int) -> assert (i=j) - in Map.iter m assert_eq + in Map.iter assert_eq m let map_op (m : foobar) : foobar = let increment = fun (_: int) (j: int) -> j+1 - in Map.map m increment + in Map.map increment m let fold_op (m : foobar) : foobar = let aggregate = fun (i: int) (j: int * int) -> i + j.0 + j.1 - in Map.fold m 10 aggregate + in Map.fold aggregate m 10 let deep_op (m: foobar) : foobar = let coco = 0,m in diff --git a/src/test/contracts/set_arithmetic-1.ligo b/src/test/contracts/set_arithmetic-1.ligo index d0b16263e..87e2621b2 100644 --- a/src/test/contracts/set_arithmetic-1.ligo +++ b/src/test/contracts/set_arithmetic-1.ligo @@ -7,11 +7,11 @@ function iter_op (const s : set(int)) : int is begin r := r + i ; end with unit ; - set_iter(s , aggregate) ; + set_iter(aggregate, s) ; end with r function fold_op (const s : set(int)) : int is block { function aggregate (const i : int ; const j : int) : int is i + j - } with set_fold(s , 15 , aggregate) + } with set_fold(aggregate, s , 15) diff --git a/src/test/contracts/set_arithmetic-1.mligo b/src/test/contracts/set_arithmetic-1.mligo index 811b5b7af..6c31953b4 100644 --- a/src/test/contracts/set_arithmetic-1.mligo +++ b/src/test/contracts/set_arithmetic-1.mligo @@ -3,4 +3,4 @@ let aggregate (i : int) (j : int) : int = i + j let fold_op (s : int set) : int = - Set.fold s 15 aggregate + Set.fold aggregate s 15 diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index d11e01f30..106475b03 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -741,6 +741,11 @@ let list () : unit result = let expected = ez [144 ; 51 ; 42 ; 120 ; 421] in expect_eq_evaluate program "bl" expected in + let%bind () = + expect_eq program "fold_op" + (e_list [e_int 2 ; e_int 4 ; e_int 7]) + (e_int 23) + in let%bind () = expect_eq program "iter_op" (e_list [e_int 2 ; e_int 4 ; e_int 7])