diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 06999bded..90e6dffad 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -194,6 +194,10 @@ module Simplify = struct ("List.map" , "LIST_MAP") ; ("List.fold" , "LIST_FOLD") ; + ("Loop.fold_while" , "FOLD_WHILE") ; + ("continue" , "CONTINUE") ; + ("stop" , "STOP") ; + ("Operation.transaction" , "CALL") ; ("Operation.get_contract" , "CONTRACT") ; ("int" , "INT") ; @@ -563,6 +567,25 @@ module Typer = struct let%bind () = assert_eq_1 ~msg:"res init" res init in ok res + (** FOLD_WHILE is a fold operation that takes an initial value of a certain type + and then iterates on it until a condition is reached. The auxillary function + that does the fold returns either boolean true or boolean false to indicate + whether the fold should continue or not. Necessarily then the initial value + must match the input parameter of the auxillary function, and the auxillary + should return type (bool * input) *) + let fold_while = typer_2 "FOLD_WHILE" @@ fun init body -> + let%bind (arg, result) = get_t_function body in + let%bind () = assert_eq_1 arg init in + let%bind () = assert_eq_1 (t_pair (t_bool ()) init ()) result + in ok init + + (* Continue and Stop are just syntactic sugar for building a pair (bool * a') *) + let continue = typer_1 "CONTINUE" @@ fun arg -> + ok @@ t_pair (t_bool ()) arg () + + let stop = typer_1 "STOP" @@ fun arg -> + ok (t_pair (t_bool ()) arg ()) + let not_ = typer_1 "NOT" @@ fun elt -> if eq_1 elt (t_bool ()) then ok @@ t_bool () @@ -641,6 +664,9 @@ module Typer = struct map_find_opt ; map_map ; map_fold ; + fold_while ; + continue ; + stop ; map_iter ; map_get_force ; map_get ; @@ -726,6 +752,13 @@ module Compiler = struct ("MAP_FIND_OPT" , simple_binary @@ prim I_GET) ; ("MAP_ADD" , simple_ternary @@ seq [dip (i_some) ; prim I_UPDATE]) ; ("MAP_UPDATE" , simple_ternary @@ prim I_UPDATE) ; + ("FOLD_WHILE" , simple_binary @@ seq [(i_push (prim T_bool) (prim D_True)) ; + prim ~children:[seq [dip i_dup; i_exec; i_unpair]] I_LOOP ; + i_swap ; i_drop]) ; + ("CONTINUE" , simple_unary @@ seq [(i_push (prim T_bool) (prim D_True)) ; + i_pair]) ; + ("STOP" , simple_unary @@ seq [(i_push (prim T_bool) (prim D_False)) ; + i_pair]) ; ("SIZE" , simple_unary @@ prim I_SIZE) ; ("FAILWITH" , simple_unary @@ prim I_FAILWITH) ; ("ASSERT_INFERRED" , simple_binary @@ i_if (seq [i_failwith]) (seq [i_drop ; i_push_unit])) ; diff --git a/src/test/contracts/loop.mligo b/src/test/contracts/loop.mligo new file mode 100644 index 000000000..12d1f8f5f --- /dev/null +++ b/src/test/contracts/loop.mligo @@ -0,0 +1,33 @@ +(* Test loops in CameLIGO *) + +let aux_simple (i: int) : bool * int = + if i < 100 then continue (i + 1) else stop i + +let counter_simple (n: int) : int = + Loop.fold_while n aux_simple + +type sum_aggregator = { + counter : int ; + sum : int ; +} + +let counter (n : int) : int = + let initial : sum_aggregator = { counter = 0 ; sum = 0 } in + let out : sum_aggregator = Loop.fold_while initial (fun (prev: sum_aggregator) -> + if prev.counter <= n then + continue ({ counter = prev.counter + 1 ; sum = prev.counter + prev.sum }) + else + stop ({ counter = prev.counter ; sum = prev.sum }) + ) in out.sum + +let aux_nest (prev: sum_aggregator) : bool * sum_aggregator = + if prev.counter < 100 then + continue ({ counter = prev.counter + 1 ; + sum = prev.sum + Loop.fold_while prev.counter aux_simple}) + else + stop ({ counter = prev.counter ; sum = prev.sum }) + +let counter_nest (n: int) : int = + let initial : sum_aggregator = { counter = 0 ; sum = 0 } in + let out : sum_aggregator = Loop.fold_while initial aux_nest + in out.sum diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index a77ff0bd3..62615e80f 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -805,6 +805,24 @@ let for_fail () : unit result = let%bind () = expect_fail program "main" (e_nat 0) in ok () *) +let loop_mligo () : unit result = + let%bind program = mtype_file "./contracts/loop.mligo" in + let%bind () = + let input = e_int 0 in + let expected = e_int 100 in + expect_eq program "counter_simple" input expected + in + let%bind () = + let input = e_int 100 in + let expected = e_int 5050 in + expect_eq program "counter" input expected + in + let%bind () = + let input = e_int 100 in + let expected = e_int 10000 in + expect_eq program "counter_nest" input expected + in ok () + let matching () : unit result = let%bind program = type_file "./contracts/match.ligo" in let%bind () = @@ -1152,6 +1170,7 @@ let main = test_suite "Integration (End to End)" [ test "big_map (mligo)" mbig_map ; test "list" list ; test "loop" loop ; + test "loop (mligo)" loop_mligo ; test "matching" matching ; test "declarations" declarations ; test "quote declaration" quote_declaration ;