diff --git a/src/passes/2-simplify/dune b/src/passes/2-simplify/dune index 9649d13dc..e27b5139d 100644 --- a/src/passes/2-simplify/dune +++ b/src/passes/2-simplify/dune @@ -6,6 +6,7 @@ tezos-utils parser ast_simplified + self_ast_simplified operators) (modules ligodity pascaligo simplify) (preprocess diff --git a/src/passes/2-simplify/pascaligo.ml b/src/passes/2-simplify/pascaligo.ml index d9fb6266d..4dee47dec 100644 --- a/src/passes/2-simplify/pascaligo.ml +++ b/src/passes/2-simplify/pascaligo.ml @@ -68,16 +68,6 @@ module Errors = struct ] in error ~data title message - let unsupported_for_loops region = - let title () = "bounded iterators" in - let message () = - Format.asprintf "only simple for loops are supported for now" in - let data = [ - ("loop_loc", - fun () -> Format.asprintf "%a" Location.pp_lift @@ region) - ] in - error ~data title message - let unsupported_non_var_pattern p = let title () = "pattern is not a variable" in let message () = @@ -137,6 +127,17 @@ module Errors = struct ] in error ~data title message + let unsupported_deep_access_for_collection for_col = + let title () = "deep access in loop over collection" in + let message () = + Format.asprintf "currently, we do not support deep \ + accesses in loops over collection" in + let data = [ + ("pattern_loc", + fun () -> Format.asprintf "%a" Location.pp_lift @@ for_col.Region.region) + ] in + error ~data title message + (* Logging *) let simplifying_instruction t = @@ -667,8 +668,14 @@ and simpl_single_instruction : Raw.instruction -> (_ -> expression result) resul let%bind body = simpl_block l.block.value in let%bind body = body None in return_statement @@ e_loop cond body - | Loop (For (ForInt {region; _} | ForCollect {region ; _})) -> - fail @@ unsupported_for_loops region + | Loop (For (ForInt fi)) -> + let%bind loop = simpl_for_int fi.value in + let%bind loop = loop None in + return_statement @@ loop + | Loop (For (ForCollect fc)) -> + let%bind loop = simpl_for_collect fc.value in + let%bind loop = loop None in + return_statement @@ loop | Cond c -> ( let (c , loc) = r_split c in let%bind expr = simpl_expression c.test in @@ -961,5 +968,206 @@ and simpl_statements : Raw.statements -> (_ -> expression result) result = and simpl_block : Raw.block -> (_ -> expression result) result = fun t -> simpl_statements t.statements +and simpl_for_int : Raw.for_int -> (_ -> expression result) result = fun fi -> + (* cond part *) + let var = e_variable fi.assign.value.name.value in + let%bind value = simpl_expression fi.assign.value.expr in + let%bind bound = simpl_expression fi.bound in + let comp = e_annotation (e_constant "LE" [var ; bound]) t_bool + in + (* body part *) + let%bind body = simpl_block fi.block.value in + let%bind body = body None in + let step = e_int 1 in + let ctrl = e_assign + fi.assign.value.name.value [] (e_constant "ADD" [ var ; step ]) in + let rec add_to_seq expr = match expr.expression with + | E_sequence (_,a) -> add_to_seq a + | _ -> e_sequence body ctrl in + let body' = add_to_seq body in + let loop = e_loop comp body' in + return_statement @@ e_let_in (fi.assign.value.name.value, Some t_int) value loop + +(** simpl_for_collect + For loops over collections, like + + ``` concrete syntax : + for x : int in set myset + begin + myint := myint + x ; + myst := myst ^ "to" ; + end + ``` + + are implemented using a MAP_FOLD, LIST_FOLD or SET_FOLD: + + ``` pseudo Ast_simplified + let #COMPILER#folded_record = list_fold( mylist , + record st = st; acc = acc; end; + lamby = fun arguments -> ( + let #COMPILER#acc = arguments.0 in + let #COMPILER#elt = arguments.1 in + #COMPILER#acc.myint := #COMPILER#acc.myint + #COMPILER#elt ; + #COMPILER#acc.myst := #COMPILER#acc.myst ^ "to" ; + #COMPILER#acc + ) + ) in + { + myst := #COMPILER#folded_record.myst ; + myint := #COMPILER#folded_record.myint ; + } + ``` + + We are performing the following steps: + 1) Simplifying the for body using ̀simpl_block` + + 2) Detect the free variables and build a list of their names + (myint and myst in the previous example) + + 3) Build the initial record (later passed as 2nd argument of + `MAP/SET/LIST_FOLD`) capturing the environment using the + free variables list of (2) + + 4) In the filtered body of (1), replace occurences: + - free variable of name X as rhs ==> accessor `#COMPILER#acc.X` + - free variable of name X as lhs ==> accessor `#COMPILER#acc.X` + And, in the case of a map: + - references to the iterated key ==> variable `#COMPILER#elt_key` + - references to the iterated value ==> variable `#COMPILER#elt_value` + in the case of a set/list: + - references to the iterated value ==> variable `#COMPILER#elt` + + 5) Append the return value to the body + + 6) Prepend the declaration of the lambda arguments to the body which + is a serie of `let .. in`'s + Note that the parameter of the lambda ̀arguments` is a tree of + tuple holding: + * In the case of `list` or ̀set`: + ( folding record , current list/set element ) as + ( #COMPILER#acc , #COMPILER#elt ) + * In the case of `map`: + ( folding record , current map key , current map value ) as + ( #COMPILER#acc , #COMPILER#elt_key , #COMPILER#elt_value ) + + 7) Build the lambda using the final body of (6) + + 8) Build a sequence of assignments for all the captured variables + to their new value, namely an access to the folded record + (#COMPILER#folded_record) + + 9) Attach the sequence of 8 to the ̀let .. in` declaration + of #COMPILER#folded_record + +**) +and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun fc -> + (* STEP 1 *) + let%bind for_body = simpl_block fc.block.value in + let%bind for_body = for_body None in + (* STEP 2 *) + let%bind captured_name_list = Self_ast_simplified.fold_expression + (fun (prev : type_name list) (ass_exp : expression) -> + match ass_exp.expression with + | E_assign ( name , _ , _ ) -> + if (String.contains name '#') then + ok prev + else + ok (name::prev) + | _ -> ok prev ) + [] + for_body in + (* STEP 3 *) + let add_to_record (prev: expression type_name_map) (captured_name: string) = + SMap.add captured_name (e_variable captured_name) prev in + let init_record = e_record (List.fold_left add_to_record SMap.empty captured_name_list) in + (* STEP 4 *) + let replace exp = + match exp.expression with + (* replace references to fold accumulator as rhs *) + | E_assign ( name , path , expr ) -> ( + match path with + | [] -> ok @@ e_assign "#COMPILER#acc" [Access_record name] expr + (* This fails for deep accesses, see LIGO-131 LIGO-134 *) + | _ -> + (* ok @@ e_assign "#COMPILER#acc" ((Access_record name)::path) expr) *) + fail @@ unsupported_deep_access_for_collection fc.block ) + | E_variable name -> ( + if (List.mem name captured_name_list) then + (* replace references to fold accumulator as lhs *) + ok @@ e_accessor (e_variable "#COMPILER#acc") [Access_record name] + else match fc.collection with + (* loop on map *) + | Map _ -> + let k' = e_variable "#COMPILER#collec_elt_k" in + if ( name = fc.var.value ) then + ok @@ k' (* replace references to the the key *) + else ( + match fc.bind_to with + | Some (_,v) -> + let v' = e_variable "#COMPILER#collec_elt_v" in + if ( name = v.value ) then + ok @@ v' (* replace references to the the value *) + else ok @@ exp + | None -> ok @@ exp + ) + (* loop on set or list *) + | (Set _ | List _) -> + if (name = fc.var.value ) then + (* replace references to the collection element *) + ok @@ (e_variable "#COMPILER#collec_elt") + else ok @@ exp + ) + | _ -> ok @@ exp in + let%bind for_body = Self_ast_simplified.map_expression replace for_body in + (* STEP 5 *) + let rec add_return (expr : expression) = match expr.expression with + | E_sequence (a,b) -> e_sequence a (add_return b) + | _ -> e_sequence expr (e_variable "#COMPILER#acc") in + let for_body = add_return for_body in + (* STEP 6 *) + let for_body = + let ( arg_access: Types.access_path -> expression ) = e_accessor (e_variable "arguments") in + ( match fc.collection with + | Map _ -> + (* let acc = arg_access [Access_tuple 0 ; Access_tuple 0] in + let collec_elt_v = arg_access [Access_tuple 1 ; Access_tuple 0] in + let collec_elt_k = arg_access [Access_tuple 1 ; Access_tuple 1] in *) + (* The above should work, but not yet (see LIGO-131) *) + let temp_kv = arg_access [Access_tuple 1] in + let acc = arg_access [Access_tuple 0] in + let collec_elt_v = e_accessor (e_variable "#COMPILER#temp_kv") [Access_tuple 0] in + let collec_elt_k = e_accessor (e_variable "#COMPILER#temp_kv") [Access_tuple 1] in + e_let_in ("#COMPILER#acc", None) acc @@ + e_let_in ("#COMPILER#temp_kv", None) temp_kv @@ + e_let_in ("#COMPILER#collec_elt_k", None) collec_elt_v @@ + e_let_in ("#COMPILER#collec_elt_v", None) collec_elt_k (for_body) + | _ -> + let acc = arg_access [Access_tuple 0] in + let collec_elt = arg_access [Access_tuple 1] in + e_let_in ("#COMPILER#acc", None) acc @@ + e_let_in ("#COMPILER#collec_elt", None) collec_elt (for_body) + ) in + (* STEP 7 *) + let%bind collect = simpl_expression fc.expr in + let lambda = e_lambda "arguments" None None for_body in + let op_name = match fc.collection with + | Map _ -> "MAP_FOLD" | Set _ -> "SET_FOLD" | List _ -> "LIST_FOLD" in + let fold = e_constant op_name [collect ; init_record ; lambda] in + (* STEP 8 *) + let assign_back (prev : expression option) (captured_varname : string) : expression option = + let access = e_accessor (e_variable "#COMPILER#folded_record") + [Access_record captured_varname] in + let assign = e_assign captured_varname [] access in + match prev with + | None -> Some assign + | Some p -> Some (e_sequence p assign) in + let reassign_sequence = List.fold_left assign_back None captured_name_list in + (* STEP 9 *) + let final_sequence = match reassign_sequence with + (* None case means that no variables were captured *) + | None -> e_skip () + | Some seq -> e_let_in ("#COMPILER#folded_record", None) fold seq in + return_statement @@ final_sequence + let simpl_program : Raw.ast -> program result = fun t -> - bind_list @@ List.map simpl_declaration @@ nseq_to_list t.decl + bind_list @@ List.map simpl_declaration @@ nseq_to_list t.decl \ No newline at end of file diff --git a/src/passes/3-self_ast_simplified/helpers.ml b/src/passes/3-self_ast_simplified/helpers.ml index 0793e8420..04b641f87 100644 --- a/src/passes/3-self_ast_simplified/helpers.ml +++ b/src/passes/3-self_ast_simplified/helpers.ml @@ -1,8 +1,93 @@ open Ast_simplified open Trace -type mapper = expression -> expression result +type 'a folder = 'a -> expression -> 'a result +let rec fold_expression : 'a folder -> 'a -> expression -> 'a result = fun f init e -> + let self = fold_expression f in + let%bind init' = f init e in + match e.expression with + | E_literal _ | E_variable _ | E_skip -> ok init' + | E_list lst | E_set lst | E_tuple lst | E_constant (_ , lst) -> ( + let%bind res' = bind_fold_list self init' lst in + ok res' + ) + | E_map lst | E_big_map lst -> ( + let%bind res' = bind_fold_list (bind_fold_pair self) init' lst in + ok res' + ) + | E_look_up ab | E_sequence ab | E_loop ab | E_application ab -> ( + let%bind res' = bind_fold_pair self init' ab in + ok res' + ) + | E_lambda { binder = _ ; input_type = _ ; output_type = _ ; result = e } + | E_annotation (e , _) | E_constructor (_ , e) -> ( + let%bind res' = self init' e in + ok res' + ) + | E_assign (_ , path , e) | E_accessor (e , path) -> ( + let%bind res' = fold_path f init' path in + let%bind res' = self res' e in + ok res' + ) + | E_matching (e , cases) -> ( + let%bind res = self init' e in + let%bind res = fold_cases f res cases in + ok res + ) + | E_record m -> ( + let aux init'' _ expr = + let%bind res' = fold_expression self init'' expr in + ok res' + in + let%bind res = bind_fold_smap aux (ok init') m in + ok res + ) + | E_let_in { binder = _ ; rhs ; result } -> ( + let%bind res = self init' rhs in + let%bind res = self res result in + ok res + ) +and fold_path : 'a folder -> 'a -> access_path -> 'a result = fun f init p -> bind_fold_list (fold_access f) init p + +and fold_access : 'a folder -> 'a -> access -> 'a result = fun f init a -> + match a with + | Access_map e -> ( + let%bind e' = fold_expression f init e in + ok e' + ) + | _ -> ok init + +and fold_cases : 'a folder -> 'a -> matching_expr -> 'a result = fun f init m -> + match m with + | Match_bool { match_true ; match_false } -> ( + let%bind res = fold_expression f init match_true in + let%bind res = fold_expression f res match_false in + ok res + ) + | Match_list { match_nil ; match_cons = (_ , _ , cons) } -> ( + let%bind res = fold_expression f init match_nil in + let%bind res = fold_expression f res cons in + ok res + ) + | Match_option { match_none ; match_some = (_ , some) } -> ( + let%bind res = fold_expression f init match_none in + let%bind res = fold_expression f res some in + ok res + ) + | Match_tuple (_ , e) -> ( + let%bind res = fold_expression f init e in + ok res + ) + | Match_variant lst -> ( + let aux init' ((_ , _) , e) = + let%bind res' = fold_expression f init' e in + ok res' in + let%bind res = bind_fold_list aux init lst in + ok res + ) + +type mapper = expression -> expression result let rec map_expression : mapper -> expression -> expression result = fun f e -> let self = map_expression f in let%bind e' = f e in diff --git a/src/passes/3-self_ast_simplified/self_ast_simplified.ml b/src/passes/3-self_ast_simplified/self_ast_simplified.ml index aa18b4a8c..a1ce4b580 100644 --- a/src/passes/3-self_ast_simplified/self_ast_simplified.ml +++ b/src/passes/3-self_ast_simplified/self_ast_simplified.ml @@ -21,3 +21,7 @@ let all_program = let all_expression = let all_p = List.map Helpers.map_expression all in bind_chain all_p + +let map_expression = Helpers.map_expression + +let fold_expression = Helpers.fold_expression diff --git a/src/passes/4-typer/typer.ml b/src/passes/4-typer/typer.ml index 28258b9fb..99d8adf3c 100644 --- a/src/passes/4-typer/typer.ml +++ b/src/passes/4-typer/typer.ml @@ -615,6 +615,36 @@ and type_expression : environment -> ?tv_opt:O.type_value -> I.expression -> O.a let output_type = body.type_annotation in return (E_lambda {binder = fst binder ; body}) (t_function input_type output_type ()) ) + | E_constant ( ("LIST_FOLD"|"MAP_FOLD"|"SET_FOLD") as opname , + [ collect ; + init_record ; + ( { expression = (I.E_lambda { binder = (lname, None) ; + input_type = None ; + output_type = None ; + result }) ; + location = _ }) as _lambda + ] ) -> + (* this special case is here force annotation of the untyped lambda + generated by pascaligo's for_collect loop *) + let%bind (v_col , v_initr ) = bind_map_pair (type_expression e) (collect , init_record ) in + let tv_col = get_type_annotation v_col in (* this is the type of the collection *) + let tv_out = get_type_annotation v_initr in (* this is the output type of the lambda*) + let%bind input_type = match tv_col.type_value' with + | O.T_constant ( ("list"|"set") , t) -> ok @@ t_tuple (tv_out::t) () + | O.T_constant ( "map" , t) -> ok @@ t_tuple (tv_out::[(t_tuple t ())]) () + | _ -> + let wtype = Format.asprintf + "Loops over collections expect lists, sets or maps, got type %a" O.PP.type_value tv_col in + fail @@ simple_error wtype in + let e' = Environment.add_ez_binder lname input_type e in + let%bind body = type_expression ?tv_opt:(Some tv_out) e' result in + let output_type = body.type_annotation in + let lambda' = make_a_e (E_lambda {binder = lname ; body}) (t_function input_type output_type ()) e in + let lst' = [v_col; v_initr ; lambda'] in + let tv_lst = List.map get_type_annotation lst' in + let%bind (opname', tv) = + type_constant opname tv_lst tv_opt ae.location in + return (E_constant (opname' , lst')) tv | E_constant (name, lst) -> let%bind lst' = bind_list @@ List.map (type_expression e) lst in let tv_lst = List.map get_type_annotation lst' in diff --git a/src/test/contracts/loop.ligo b/src/test/contracts/loop.ligo index 03cc751a7..2f250cf18 100644 --- a/src/test/contracts/loop.ligo +++ b/src/test/contracts/loop.ligo @@ -16,12 +16,140 @@ function while_sum (var n : nat) : nat is block { } } with r -(* function for_sum (var n : nat) : nat is block { - for i := 1 to 100 +function for_sum (var n : nat) : int is block { + var acc : int := 0 ; + for i := 1 to int(n) begin - n := n + 1; - end } - with n *) + acc := acc + i ; + end +} with acc + +function for_collection_list (var nee : unit) : (int * string) is block { + var acc : int := 0 ; + var st : string := "to" ; + var mylist : list(int) := list 1 ; 1 ; 1 end ; + for x : int in list mylist + begin + acc := acc + x ; + st := st^"to" ; + end +} with (acc, st) + +function for_collection_set (var nee : unit) : (int * string) is block { + var acc : int := 0 ; + var st : string := "to" ; + var myset : set(int) := set 1 ; 2 ; 3 end ; + for x : int in set myset + begin + acc := acc + x ; + st := st^"to" ; + end +} with (acc, st) + +function for_collection_if_and_local_var (var nee : unit) : int is block { + var acc : int := 0 ; + const theone : int = 1 ; + var myset : set(int) := set 1 ; 2 ; 3 end ; + for x : int in set myset + begin + const thetwo : int = 2 ; + if (x=theone) then + acc := acc + x ; + else if (x=thetwo) then + acc := acc + thetwo ; + else + acc := acc + 10 ; + end +} with acc + +function for_collection_rhs_capture (var nee : unit) : int is block { + var acc : int := 0 ; + const mybigint : int = 1000 ; + var myset : set(int) := set 1 ; 2 ; 3 end ; + for x : int in set myset + begin + if (x=1) then + acc := acc + mybigint ; + else + acc := acc + 10 ; + end +} with acc + +function for_collection_proc_call (var nee : unit) : int is block { + var acc : int := 0 ; + var myset : set(int) := set 1 ; 2 ; 3 end ; + for x : int in set myset + begin + if (x=1) then + acc := acc + for_collection_rhs_capture(unit) ; + else + acc := acc + 10 ; + end +} with acc + +function for_collection_comp_with_acc (var nee : unit) : int is block { + var myint : int := 0 ; + var mylist : list(int) := list 1 ; 10 ; 15 end; + for x : int in list mylist + begin + if (x < myint) then skip ; + else myint := myint + 10 ; + end +} with myint + +function for_collection_with_patches (var nee : unit) : map(string,int) is block { + var myint : int := 12 ; + var mylist : list(string) := list "I" ; "am" ; "foo" end; + var mymap : map(string,int) := map end; + for x : string in list mylist + begin + patch mymap with map [ x -> myint ]; + end +} with mymap + +function for_collection_empty (var nee : unit) : int is block { + var acc : int := 0 ; + var myset : set(int) := set 1 ; 2 ; 3 end ; + for x : int in set myset + begin + skip ; + end +} with acc + +function for_collection_map_kv (var nee : unit) : (int * string) is block { + var acc : int := 0 ; + var st : string := "" ; + var mymap : map(string,int) := map "1" -> 1 ; "2" -> 2 ; "3" -> 3 end ; + for k -> v : (string * int) in map mymap + begin + acc := acc + v ; + st := st^k ; + end +} with (acc, st) + +function for_collection_map_k (var nee : unit) : string is block { + var st : string := "" ; + var mymap : map(string,int) := map "1" -> 1 ; "2" -> 2 ; "3" -> 3 end ; + for k : string in map mymap + begin + st := st^k ; + end +} with st + +// function nested_for_collection (var nee : unit) : (int*string) is block { +// var myint : int := 0; +// var myst : string := ""; +// var mylist : list(int) := list 1 ; 2 ; 3 end ; +// for i : int in list mylist +// begin +// myint := myint + i ; +// var myset : set(string) := set "1" ; "2" ; "3" end ; +// for st : string in set myset +// begin +// myst := myst ^ st ; +// end +// end +// } with (myint,myst) function dummy (const n : nat) : nat is block { while (False) block { skip } diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index ed60a0966..10f4d4d06 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -647,24 +647,58 @@ let loop () : unit result = let%bind () = let make_input = e_nat in let make_expected = e_nat in - expect_eq_n_pos program "dummy" make_input make_expected - in + expect_eq_n_pos program "dummy" make_input make_expected in let%bind () = let make_input = e_nat in let make_expected = e_nat in - expect_eq_n_pos_mid program "counter" make_input make_expected - in + expect_eq_n_pos_mid program "counter" make_input make_expected in let%bind () = let make_input = e_nat in let make_expected = fun n -> e_nat (n * (n + 1) / 2) in - expect_eq_n_pos_mid program "while_sum" make_input make_expected - in(* For loop is currently unsupported - - let%bind () = + expect_eq_n_pos_mid program "while_sum" make_input make_expected in + let%bind () = let make_input = e_nat in - let make_expected = fun n -> e_nat (n * (n + 1) / 2) in - expect_eq_n_pos_mid program "for_sum" make_input make_expected - in *) + let make_expected = fun n -> e_int (n * (n + 1) / 2) in + expect_eq_n_pos_mid program "for_sum" make_input make_expected in + let input = e_unit () in + let%bind () = + let expected = e_pair (e_int 3) (e_string "totototo") in + expect_eq program "for_collection_list" input expected in + let%bind () = + let expected = e_pair (e_int 6) (e_string "totototo") in + expect_eq program "for_collection_set" input expected in + let%bind () = + let expected = e_pair (e_int 6) (e_string "123") in + expect_eq program "for_collection_map_kv" input expected in + let%bind () = + let expected = (e_string "123") in + expect_eq program "for_collection_map_k" input expected in + let%bind () = + let expected = (e_int 0) in + expect_eq program "for_collection_empty" input expected in + let%bind () = + let expected = (e_int 13) in + expect_eq program "for_collection_if_and_local_var" input expected in + let%bind () = + let expected = (e_int 1020) in + expect_eq program "for_collection_rhs_capture" input expected in + let%bind () = + let expected = (e_int 1040) in + expect_eq program "for_collection_proc_call" input expected in + let%bind () = + let expected = (e_int 20) in + expect_eq program "for_collection_comp_with_acc" input expected in + (* let%bind () = + let expected = e_pair (e_int 6) (e_string "123123123") in + expect_eq program "nested_for_collection" input expected in *) + let%bind () = + let ez lst = + let open Ast_simplified.Combinators in + let lst' = List.map (fun (x, y) -> e_string x, e_int y) lst in + e_typed_map lst' t_string t_int + in + let expected = ez [ ("I" , 12) ; ("am" , 12) ; ("foo" , 12) ] in + expect_eq program "for_collection_with_patches" input expected in ok () (* Don't know how to assert parse error happens in this test framework