diff --git a/src/passes/2-simplify/pascaligo.ml b/src/passes/2-simplify/pascaligo.ml index 63f9d8e70..06a428230 100644 --- a/src/passes/2-simplify/pascaligo.ml +++ b/src/passes/2-simplify/pascaligo.ml @@ -41,13 +41,28 @@ let detect_free_variables (for_body : expression) (local_decl_names : string lis | E_assign ( name , _ , _ ) -> if is_compiler_generated name then ok prev else ok (name::prev) + | E_constant (n, [a;b]) + when n="OR" || n="AND" || n="LT" || n="GT" || + n="LE" || n="GE" || n="EQ" || n="NEQ" -> ( + match (a.expression,b.expression) with + | E_variable na , E_variable nb -> + let ret = [] in + let ret = if not (is_compiler_generated na) then + na::ret else ret in + let ret = if not (is_compiler_generated nb) then + nb::ret else ret in + ok (ret@prev) + | E_variable n , _ + | _ , E_variable n -> + if not (is_compiler_generated n) then + ok (n::prev) else ok prev + | _ -> ok prev) | _ -> ok prev ) [] for_body in ok @@ SSet.elements @@ SSet.diff (SSet.of_list captured_names) (SSet.of_list local_decl_names) - module Errors = struct let unsupported_cst_constr p = let title () = "constant constructor" in @@ -1125,11 +1140,14 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun let elt_v_name = match fc.bind_to with | Some v -> "#COMPILER#elt"^(snd v).value | None -> "#COMPILER#elt_unused" in + let element_names = ok @@ match fc.bind_to with + | Some v -> [fc.var.value;(snd v).value] + | None -> [fc.var.value] in (* STEP 1 *) let%bind for_body = simpl_block fc.block.value in let%bind for_body = for_body None in (* STEP 2 *) - let%bind local_decl_name_list = detect_local_declarations for_body in + let%bind local_decl_name_list = bind_concat (detect_local_declarations for_body) element_names in let%bind captured_name_list = detect_free_variables for_body local_decl_name_list in (* STEP 3 *) let add_to_record (prev: expression type_name_map) (captured_name: string) = diff --git a/src/test/contracts/loop.ligo b/src/test/contracts/loop.ligo index 9027dc1cb..50e87da04 100644 --- a/src/test/contracts/loop.ligo +++ b/src/test/contracts/loop.ligo @@ -178,3 +178,16 @@ function nested_for_collection_local_var (var nee : unit) : (int*string) is bloc function dummy (const n : nat) : nat is block { while False block { skip } } with n + +function inner_capture_in_conditional_block (var nee : unit) : bool*int is block { + var count : int := 1 ; + var ret : bool := False ; + var mylist : list(int) := list 1 ; 2 ; 3 end ; + for it1 in list mylist block { + for it2 in list mylist block { + if count = it2 then ret := not (ret) + else skip; + }; + count := count + 1; + } +} with (ret,count) \ No newline at end of file diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index b509a845f..2a7370845 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -845,6 +845,9 @@ let loop () : unit result = let expected = e_pair (e_int 24) (e_string "123123123") in expect_eq program "nested_for_collection_local_var" input expected in + let%bind () = + let expected = e_pair (e_bool true) (e_int 4) in + expect_eq program "inner_capture_in_conditional_block" input expected in let%bind () = let ez lst = let open Ast_simplified.Combinators in diff --git a/vendors/ligo-utils/simple-utils/trace.ml b/vendors/ligo-utils/simple-utils/trace.ml index b15a05e1d..04f8b511d 100644 --- a/vendors/ligo-utils/simple-utils/trace.ml +++ b/vendors/ligo-utils/simple-utils/trace.ml @@ -567,6 +567,11 @@ let bind_fold_smap f init (smap : _ X_map.String.t) = let bind_map_smap f smap = bind_smap (X_map.String.map f smap) +let bind_concat (l1:'a list result) (l2: 'a list result) = + let%bind l1' = l1 in + let%bind l2' = l2 in + ok @@ (l1' @ l2') + let bind_map_list f lst = bind_list (List.map f lst) let rec bind_map_list_seq f lst = match lst with | [] -> ok []