diff --git a/src/passes/2-simplify/pascaligo.ml b/src/passes/2-simplify/pascaligo.ml index 4ae15d8dd..66fe46481 100644 --- a/src/passes/2-simplify/pascaligo.ml +++ b/src/passes/2-simplify/pascaligo.ml @@ -68,16 +68,6 @@ module Errors = struct ] in error ~data title message - (* let unsupported_for_loops region = - let title () = "bounded iterators" in - let message () = - Format.asprintf "only simple for loops are supported for now" in - let data = [ - ("loop_loc", - fun () -> Format.asprintf "%a" Location.pp_lift @@ region) - ] in - error ~data title message *) - let unsupported_non_var_pattern p = let title () = "pattern is not a variable" in let message () = @@ -148,6 +138,16 @@ module Errors = struct ] in error ~data title message + let unsupported_for_collect_map for_col = + let title () = "for loop over map" in + let message () = + Format.asprintf "for loops over map are not supported yet" in + let data = [ + ("loop_loc", + fun () -> Format.asprintf "%a" Location.pp_lift @@ for_col.Region.region) + ] in + error ~data title message + (* Logging *) let simplifying_instruction t = @@ -999,6 +999,7 @@ and simpl_for_int : Raw.for_int -> (_ -> expression result) result = fun fi -> return_statement @@ e_let_in (fi.assign.value.name.value, Some t_int) value loop and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun fc -> + match fc.collection with | Map _ -> fail @@ unsupported_for_collect_map fc.block | _ -> let statements = npseq_to_list fc.block.value.statements in (* build initial record *) let filter_assignments (el : Raw.statement) : Raw.instruction option = match el with @@ -1027,16 +1028,43 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun (* replace references to fold accumulator as rhs *) | E_assign ( name , path , expr ) -> ( match path with | [] -> ok @@ e_assign "_COMPILER_acc" [Access_record name] expr - (* This fails for deep accesses, see LIGO-131 *) - | _ -> fail @@ unsupported_deep_access_for_collection fc.block ) - | E_variable name -> - if (name = fc.var.value ) then - (* replace references to the collection element *) - ok @@ (e_variable "_COMPILER_collec_elt") - else if (List.mem name captured_name_list) then - (* replace references to fold accumulator as lhs *) - ok @@ e_accessor (e_variable "_COMPILER_acc") [Access_record name] - else ok @@ exp + (* This fails for deep accesses, see LIGO-131 LIGO-134 *) + | _ -> + (* ok @@ e_assign "_COMPILER_acc" ((Access_record name)::path) expr) *) + fail @@ unsupported_deep_access_for_collection fc.block ) + | E_variable name -> ( match fc.collection with + (* loop on map *) + | Map _ -> + let k' = e_variable "_COMPILER_collec_elt_k" in + let v' = e_variable "_COMPILER_collec_elt_v" in + ( match fc.bind_to with + | Some (_,v) -> + if ( name = fc.var.value ) then + ok @@ k' (* replace references to the the key *) + else if ( name = v.value ) then + ok @@ v' (* replace references to the the value *) + else if (List.mem name captured_name_list) then + (* replace references to fold accumulator as lhs *) + ok @@ e_accessor (e_variable "_COMPILER_acc") [Access_record name] + else ok @@ exp + | None -> + if ( name = fc.var.value ) then + ok @@ k' (* replace references to the key *) + else if (List.mem name captured_name_list) then + (* replace references to fold accumulator as lhs *) + ok @@ e_accessor (e_variable "_COMPILER_acc") [Access_record name] + else ok @@ exp + ) + (* loop on set or list *) + | (Set _ | List _) -> + if (name = fc.var.value ) then + (* replace references to the collection element *) + ok @@ (e_variable "_COMPILER_collec_elt") + else if (List.mem name captured_name_list) then + (* replace references to fold accumulator as lhs *) + ok @@ e_accessor (e_variable "_COMPILER_acc") [Access_record name] + else ok @@ exp + ) | _ -> ok @@ exp in let%bind for_body = Self_ast_simplified.map_expression replace for_body in (* append the return value (the accumulator) to the for body *) @@ -1044,12 +1072,24 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun | E_sequence (a,b) -> e_sequence a (add_return b) | _ -> e_sequence expr (e_variable "_COMPILER_acc") in let for_body = add_return for_body in - (* prepend for body with args declaration (accumulator and collection element)*) + (* prepend for body with args declaration (accumulator and collection elements *) let%bind elt_type = simpl_type_expression fc.elt_type in - let acc = e_accessor (e_variable "arguments") [Access_tuple 0] in - let collec_elt = e_accessor (e_variable "arguments") [Access_tuple 1] in - let for_body = e_let_in ("_COMPILER_acc", None) acc @@ - e_let_in ("_COMPILER_collec_elt", Some elt_type) collec_elt (for_body) in + let for_body = + let ( arg_access: Types.access_path -> expression ) = e_accessor (e_variable "arguments") in + ( match fc.collection with + | Map _ -> + let acc = arg_access [Access_tuple 0 ; Access_tuple 0] in + let collec_elt_v = arg_access [Access_tuple 1 ; Access_tuple 0] in + let collec_elt_k = arg_access [Access_tuple 1 ; Access_tuple 1] in + e_let_in ("_COMPILER_acc", None) acc @@ + e_let_in ("_COMPILER_collec_elt_k", None) collec_elt_v @@ + e_let_in ("_COMPILER_collec_elt_v", None) collec_elt_k (for_body) + | _ -> + let acc = arg_access [Access_tuple 0] in + let collec_elt = arg_access [Access_tuple 1] in + e_let_in ("_COMPILER_acc", None) acc @@ + e_let_in ("_COMPILER_collec_elt", Some elt_type) collec_elt (for_body) + ) in (* build the X_FOLD constant *) let%bind collect = simpl_expression fc.expr in let lambda = e_lambda "arguments" None None for_body in diff --git a/src/passes/4-typer/typer.ml b/src/passes/4-typer/typer.ml index 832e7b04f..99d8adf3c 100644 --- a/src/passes/4-typer/typer.ml +++ b/src/passes/4-typer/typer.ml @@ -629,13 +629,13 @@ and type_expression : environment -> ?tv_opt:O.type_value -> I.expression -> O.a let%bind (v_col , v_initr ) = bind_map_pair (type_expression e) (collect , init_record ) in let tv_col = get_type_annotation v_col in (* this is the type of the collection *) let tv_out = get_type_annotation v_initr in (* this is the output type of the lambda*) - let%bind col_inner_type = match tv_col.type_value' with - | O.T_constant ( ("list"|"set"|"map") , [t]) -> ok t + let%bind input_type = match tv_col.type_value' with + | O.T_constant ( ("list"|"set") , t) -> ok @@ t_tuple (tv_out::t) () + | O.T_constant ( "map" , t) -> ok @@ t_tuple (tv_out::[(t_tuple t ())]) () | _ -> let wtype = Format.asprintf - "Loops over collections expect lists, sets or maps, type %a" O.PP.type_value tv_col in + "Loops over collections expect lists, sets or maps, got type %a" O.PP.type_value tv_col in fail @@ simple_error wtype in - let input_type = t_tuple (tv_out::[col_inner_type]) () in let e' = Environment.add_ez_binder lname input_type e in let%bind body = type_expression ?tv_opt:(Some tv_out) e' result in let output_type = body.type_annotation in diff --git a/src/test/contracts/loop.ligo b/src/test/contracts/loop.ligo index eaa429df7..f559c2816 100644 --- a/src/test/contracts/loop.ligo +++ b/src/test/contracts/loop.ligo @@ -39,7 +39,6 @@ function for_collection_ (var nee : unit; var nuu : unit) : (int * string) is bl record st = st; acc = acc; end; var folded_record : (record st : string; acc : int end ) := list_fold(mylist , init_record , lamby) ; - skip ; st := folded_record.st ; acc := folded_record.acc ; } with (folded_record.acc , folded_record.st) @@ -66,6 +65,17 @@ function for_collection_set (var nee : unit) : (int * string) is block { end } with (acc, st) +// function for_collection_map (var nee : unit) : (int * string) is block { +// var acc : int := 0 ; +// var st : string := "" ; +// var mymap : map(string,int) := map "one" -> 1 ; "two" -> 2 ; "three" -> 3 end ; +// for k -> v : (string * int) in map mymap +// begin +// acc := acc + v ; +// st := k^st ; +// end +// } with (acc, st) + function dummy (const n : nat) : nat is block { while (False) block { skip } } with n