Pascaligo for collection loops: take locally declared variable into account

This commit is contained in:
Rémi Lesenechal 2019-11-19 13:25:48 +00:00
parent bbcacc253b
commit f0655eab28
3 changed files with 71 additions and 17 deletions

View File

@ -3,6 +3,7 @@ open Ast_simplified
module Raw = Parser.Pascaligo.AST
module SMap = Map.String
module SSet = Set.Make (String)
open Combinators
@ -14,6 +15,39 @@ let pseq_to_list = function
let get_value : 'a Raw.reg -> 'a = fun x -> x.value
let is_compiler_generated = fun name -> String.contains name '#'
let detect_local_declarations (for_body : expression) =
let%bind aux = Self_ast_simplified.fold_expression
(fun (nlist, cur_loop : type_name list * bool) (ass_exp : expression) ->
if cur_loop then
match ass_exp.expression with
| E_let_in {binder;rhs = _;result = _} ->
let (name,_) = binder in
ok (name::nlist, cur_loop)
| E_constant ("MAP_FOLD", _)
| E_constant ("SET_FOLD", _)
| E_constant ("LIST_FOLD", _) -> ok @@ (nlist, false)
| _ -> ok (nlist, cur_loop)
else
ok @@ (nlist, cur_loop)
)
([], true)
for_body in
ok @@ fst aux
let detect_free_variables (for_body : expression) (local_decl_names : string list) =
let%bind captured_names = Self_ast_simplified.fold_expression
(fun (prev : type_name list) (ass_exp : expression) ->
match ass_exp.expression with
| E_assign ( name , _ , _ ) ->
if is_compiler_generated name then ok prev
else ok (name::prev)
| _ -> ok prev )
[]
for_body in
ok @@ SSet.elements
@@ SSet.diff (SSet.of_list captured_names) (SSet.of_list local_decl_names)
module Errors = struct
let unsupported_cst_constr p =
let title () = "constant constructor" in
@ -1033,7 +1067,8 @@ and simpl_for_int : Raw.for_int -> (_ -> expression result) result = fun fi ->
2) Detect the free variables and build a list of their names
(myint and myst in the previous example)
Free variables are simply variables being assigned.
Free variables are simply variables being assigned but not defined
locally.
Note: In the case of a nested loops, assignements to a compiler
generated value (#COMPILER#acc) correspond to variables
that were already renamed in the inner loop.
@ -1094,15 +1129,8 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun
let%bind for_body = simpl_block fc.block.value in
let%bind for_body = for_body None in
(* STEP 2 *)
let%bind captured_name_list = Self_ast_simplified.fold_expression
(fun (prev : type_name list) (ass_exp : expression) ->
match ass_exp.expression with
| E_assign ( name , _ , _ ) ->
if is_compiler_generated name then ok prev
else ok (name::prev)
| _ -> ok prev )
[]
for_body in
let%bind local_decl_name_list = detect_local_declarations for_body in
let%bind captured_name_list = detect_free_variables for_body local_decl_name_list in
(* STEP 3 *)
let add_to_record (prev: expression type_name_map) (captured_name: string) =
SMap.add captured_name (e_variable captured_name) prev in
@ -1112,13 +1140,16 @@ and simpl_for_collect : Raw.for_collect -> (_ -> expression result) result = fun
match exp.expression with
(* replace references to fold accumulator as lhs *)
| E_assign ( name , path , expr ) -> (
let path' = List.filter
( fun el ->
match el with
| Access_record name -> not @@ is_compiler_generated name
| _ -> true )
((Access_record name)::path) in
ok @@ e_assign "#COMPILER#acc" path' expr)
if (List.mem name local_decl_name_list ) then
ok @@ exp
else
let path' = List.filter
( fun el ->
match el with
| Access_record name -> not @@ is_compiler_generated name
| _ -> true )
((Access_record name)::path) in
ok @@ e_assign "#COMPILER#acc" path' expr )
| E_variable name -> (
if (List.mem name captured_name_list) then
(* replace references to fold accumulator as rhs *)

View File

@ -156,6 +156,25 @@ function nested_for_collection (var nee : unit) : (int*string) is block {
end
} with (myint,mystoo)
function nested_for_collection_local_var (var nee : unit) : (int*string) is block {
var myint : int := 0;
var myst : string := "";
var mylist : list(int) := list 1 ; 2 ; 3 end ;
for i in list mylist
begin
var myst_loc : string := "" ;
myint := myint + i ;
var myset : set(string) := set "1" ; "2" ; "3" end ;
for st in set myset
begin
myint := myint + i ;
myst_loc := myst_loc ^ st ;
end;
myst := myst_loc ^ myst ;
end
} with (myint,myst)
function dummy (const n : nat) : nat is block {
while False block { skip }
} with n

View File

@ -836,6 +836,10 @@ let loop () : unit result =
let expected = e_pair (e_int 24)
(e_string "1 one,two 2 one,two 3 one,two 1 one,two 2 one,two 3 one,two 1 one,two 2 one,two 3 one,two ") in
expect_eq program "nested_for_collection" input expected in
let%bind () =
let expected = e_pair (e_int 24)
(e_string "123123123") in
expect_eq program "nested_for_collection_local_var" input expected in
let%bind () =
let ez lst =
let open Ast_simplified.Combinators in