From f0655eab281a3057dbaaeb98d885f7284df62f7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Lesenechal?= Date: Tue, 19 Nov 2019 13:25:48 +0000 Subject: [PATCH] Pascaligo for collection loops: take locally declared variable into account --- src/passes/2-simplify/pascaligo.ml | 65 ++++++++++++++++++++++-------- src/test/contracts/loop.ligo | 19 +++++++++ src/test/integration_tests.ml | 4 ++ 3 files changed, 71 insertions(+), 17 deletions(-) diff --git a/src/passes/2-simplify/pascaligo.ml b/src/passes/2-simplify/pascaligo.ml index 799b073cc..a78b11cd7 100644 --- a/src/passes/2-simplify/pascaligo.ml +++ b/src/passes/2-simplify/pascaligo.ml @@ -3,6 +3,7 @@ open Ast_simplified module Raw = Parser.Pascaligo.AST module SMap = Map.String +module SSet = Set.Make (String) open Combinators @@ -14,6 +15,39 @@ let pseq_to_list = function let get_value : 'a Raw.reg -> 'a = fun x -> x.value let is_compiler_generated = fun name -> String.contains name '#' +let detect_local_declarations (for_body : expression) = + let%bind aux = Self_ast_simplified.fold_expression + (fun (nlist, cur_loop : type_name list * bool) (ass_exp : expression) -> + if cur_loop then + match ass_exp.expression with + | E_let_in {binder;rhs = _;result = _} -> + let (name,_) = binder in + ok (name::nlist, cur_loop) + | E_constant ("MAP_FOLD", _) + | E_constant ("SET_FOLD", _) + | E_constant ("LIST_FOLD", _) -> ok @@ (nlist, false) + | _ -> ok (nlist, cur_loop) + else + ok @@ (nlist, cur_loop) + ) + ([], true) + for_body in + ok @@ fst aux + +let detect_free_variables (for_body : expression) (local_decl_names : string list) = + let%bind captured_names = Self_ast_simplified.fold_expression + (fun (prev : type_name list) (ass_exp : expression) -> + match ass_exp.expression with + | E_assign ( name , _ , _ ) -> + if is_compiler_generated name then ok prev + else ok (name::prev) + | _ -> ok prev ) + [] + for_body in + ok @@ SSet.elements + @@ SSet.diff (SSet.of_list captured_names) (SSet.of_list local_decl_names) + + module Errors = struct let unsupported_cst_constr p = let title () = "constant constructor" in @@ -1033,7 +1067,8 @@ and simpl_for_int : Raw.for_int -> (_ -> expression result) result = fun fi -> 2) Detect the free variables and build a list of their names (myint and myst in the previous example) - Free variables are simply variables being assigned. + Free variables are simply variables being assigned but not defined + locally. Note: In the case of a nested loops, assignements to a compiler generated value (#COMPILER#acc) correspond to variables that were already renamed in the inner loop. @@ -1094,15 +1129,8 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun let%bind for_body = simpl_block fc.block.value in let%bind for_body = for_body None in (* STEP 2 *) - let%bind captured_name_list = Self_ast_simplified.fold_expression - (fun (prev : type_name list) (ass_exp : expression) -> - match ass_exp.expression with - | E_assign ( name , _ , _ ) -> - if is_compiler_generated name then ok prev - else ok (name::prev) - | _ -> ok prev ) - [] - for_body in + let%bind local_decl_name_list = detect_local_declarations for_body in + let%bind captured_name_list = detect_free_variables for_body local_decl_name_list in (* STEP 3 *) let add_to_record (prev: expression type_name_map) (captured_name: string) = SMap.add captured_name (e_variable captured_name) prev in @@ -1112,13 +1140,16 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun match exp.expression with (* replace references to fold accumulator as lhs *) | E_assign ( name , path , expr ) -> ( - let path' = List.filter - ( fun el -> - match el with - | Access_record name -> not @@ is_compiler_generated name - | _ -> true ) - ((Access_record name)::path) in - ok @@ e_assign "#COMPILER#acc" path' expr) + if (List.mem name local_decl_name_list ) then + ok @@ exp + else + let path' = List.filter + ( fun el -> + match el with + | Access_record name -> not @@ is_compiler_generated name + | _ -> true ) + ((Access_record name)::path) in + ok @@ e_assign "#COMPILER#acc" path' expr ) | E_variable name -> ( if (List.mem name captured_name_list) then (* replace references to fold accumulator as rhs *) diff --git a/src/test/contracts/loop.ligo b/src/test/contracts/loop.ligo index e71d2a372..9027dc1cb 100644 --- a/src/test/contracts/loop.ligo +++ b/src/test/contracts/loop.ligo @@ -156,6 +156,25 @@ function nested_for_collection (var nee : unit) : (int*string) is block { end } with (myint,mystoo) +function nested_for_collection_local_var (var nee : unit) : (int*string) is block { + var myint : int := 0; + var myst : string := ""; + var mylist : list(int) := list 1 ; 2 ; 3 end ; + + for i in list mylist + begin + var myst_loc : string := "" ; + myint := myint + i ; + var myset : set(string) := set "1" ; "2" ; "3" end ; + for st in set myset + begin + myint := myint + i ; + myst_loc := myst_loc ^ st ; + end; + myst := myst_loc ^ myst ; + end +} with (myint,myst) + function dummy (const n : nat) : nat is block { while False block { skip } } with n diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index 7088cdb39..16bb7a666 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -836,6 +836,10 @@ let loop () : unit result = let expected = e_pair (e_int 24) (e_string "1 one,two 2 one,two 3 one,two 1 one,two 2 one,two 3 one,two 1 one,two 2 one,two 3 one,two ") in expect_eq program "nested_for_collection" input expected in + let%bind () = + let expected = e_pair (e_int 24) + (e_string "123123123") in + expect_eq program "nested_for_collection_local_var" input expected in let%bind () = let ez lst = let open Ast_simplified.Combinators in