add a pass to recompute environments

This commit is contained in:
galfour 2020-05-09 13:21:02 +02:00
parent c307f251c8
commit 98d49959b7
No known key found for this signature in database
GPG Key ID: 27289510ECBDEF5B
6 changed files with 200 additions and 9 deletions

View File

@ -4,6 +4,7 @@
(libraries
simple-utils
ast_typed
environment
)
(preprocess
(pps ppx_let bisect_ppx --conditional)

View File

@ -0,0 +1,168 @@
open Ast_typed
(*
During the modifications of the passes on `Ast_typed`, the binding
environments are not kept in sync. To palliate this, this module
recomputes them from scratch.
*)
(*
This module is very coupled to `typer.ml`. Given environments are
not used until the next pass, it makes sense to split this into
its own separate pass. This pass would go from `Ast_typed` without
environments to `Ast_typed` with embedded environments.
*)
(*
BAD!
This representation a quadratic amount of space. As environments are
linear in the size of the program, and there is a linear number of them.
*)
let rec expression : environment -> expression -> expression = fun env expr ->
(* Standard helper functions to help with the fold *)
let return ?(env' = env) content = {
expr with
environment = env' ;
expression_content = content ;
} in
let return_id = return expr.expression_content in
let self ?(env' = env) x = expression env' x in
let self_list lst = List.map self lst in
let self_2 a b = self a , self b in
let self_lmap lm = LMap.map self lm in
let self_cases cs = cases env cs in
match expr.expression_content with
| E_lambda c -> (
let (t_binder , _) = Combinators.get_t_function_exn expr.type_expression in
let env' = Environment.add_ez_binder c.binder t_binder env in
let result = self ~env' c.result in
return @@ E_lambda { c with result }
)
| E_let_in c -> (
let env' = Environment.add_ez_declaration c.let_binder c.rhs env in
let let_result = self ~env' c.let_result in
let rhs = self c.rhs in
return @@ E_let_in { c with rhs ; let_result }
)
(* rec fun_name binder -> result *)
| E_recursive c -> (
let env_fun_name = Environment.add_ez_binder c.fun_name c.fun_type env in
let (t_binder , _) = Combinators.get_t_function_exn c.fun_type in
let env_binder = Environment.add_ez_binder c.lambda.binder t_binder env_fun_name in
let result = self ~env':env_binder c.lambda.result in
let lambda = { c.lambda with result } in
return @@ E_recursive { c with lambda }
)
(* All the following cases are administrative *)
| E_literal _ -> return_id
| E_variable _ -> return_id
| E_constant c -> (
let arguments = self_list c.arguments in
return @@ E_constant { c with arguments }
)
| E_application c -> (
let (lamb , args) = self_2 c.lamb c.args in
return @@ E_application { lamb ; args }
)
| E_constructor c -> (
let element = self c.element in
return @@ E_constructor { c with element }
)
| E_record c -> (
let c' = self_lmap c in
return @@ E_record c'
)
| E_record_accessor c -> (
let record = self c.record in
return @@ E_record_accessor { c with record }
)
| E_record_update c -> (
let (record , update) = self_2 c.record c.update in
return @@ E_record_update { c with record ; update }
)
| E_matching c -> (
let matchee = self c.matchee in
let cases = self_cases c.cases in
return @@ E_matching { matchee ; cases }
)
and cases : environment -> matching_expr -> matching_expr = fun env cs ->
let return x = x in
let self ?(env' = env) x = expression env' x in
match cs with
| Match_list c -> (
let match_nil = self c.match_nil in
let match_cons =
let mc = c.match_cons in
let env_hd = Environment.add_ez_binder mc.hd mc.tv env in
let env_tl = Environment.add_ez_binder mc.tl (t_list mc.tv ()) env_hd in
let body = self ~env':env_tl mc.body in
{ mc with body }
in
return @@ Match_list { match_nil ; match_cons }
)
| Match_option c -> (
let match_none = self c.match_none in
let match_some =
let ms = c.match_some in
let env' = Environment.add_ez_binder ms.opt ms.tv env in
let body = self ~env' ms.body in
{ ms with body }
in
return @@ Match_option { match_none ; match_some }
)
| Match_tuple c -> (
let var_tvs =
try (
List.combine c.vars c.tvs
) with _ -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__))
in
let env' =
let aux prev (var , tv) =
Environment.add_ez_binder var tv prev
in
List.fold_left aux env var_tvs
in
let body = self ~env' c.body in
return @@ Match_tuple { c with body }
)
| Match_variant c -> (
let variant_type = Combinators.get_t_sum_exn c.tv in
let cases =
let aux (c : matching_content_case) =
let case =
try (
CMap.find c.constructor variant_type
) with _ -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__))
in
let env' = Environment.add_ez_binder c.pattern case.ctor_type env in
let body = self ~env' c.body in
{ c with body }
in
List.map aux c.cases
in
return @@ Match_variant { c with cases }
)
let program : environment -> program -> program = fun init_env prog ->
(*
BAD
We take the old type environment and add it to the current value environment
because type declarations are removed in the typer. They should be added back.
*)
let merge old_env re_env = {
expression_environment = re_env.expression_environment ;
type_environment = old_env.type_environment ;
} in
let aux (pre_env , rev_decls) decl_wrapped =
let (Declaration_constant c) = Location.unwrap decl_wrapped in
let expr = expression pre_env c.expr in
let post_env = Environment.add_ez_declaration c.binder c.expr pre_env in
let post_env' = merge c.post_env post_env in
let wrap_content = Declaration_constant { c with expr ; post_env = post_env' } in
let decl_wrapped' = { decl_wrapped with wrap_content } in
(post_env , decl_wrapped' :: rev_decls)
in
let (_last_env , rev_decls) = List.fold_left aux (init_env , []) prog in
List.rev rev_decls

View File

@ -10,9 +10,11 @@ let contract_passes = [
No_nested_big_map.self_typing ;
]
let all_program =
let all_program program =
let all_p = List.map Helpers.map_program all_passes in
bind_chain all_p
let%bind program' = bind_chain all_p program in
let program'' = Recompute_environment.program Environment.default program' in
ok program''
let all_expression =
let all_p = List.map Helpers.map_expression all_passes in

View File

@ -174,9 +174,17 @@ let get_t_pair (t:type_expression) : (type_expression * type_expression) result
ok List.(nth lst 0 , nth lst 1)
| _ -> fail @@ Errors.not_a_x_type "pair (tuple with two elements)" t ()
let get_t_function (t:type_expression) : (type_expression * type_expression) result = match t.type_content with
| T_arrow {type1;type2} -> ok (type1,type2)
| _ -> simple_fail "not a function"
let get_t_function_opt (t:type_expression) : (type_expression * type_expression) option = match t.type_content with
| T_arrow {type1;type2} -> Some (type1,type2)
| _ -> None
let get_t_function t =
trace_option (Errors.not_a_x_type "function" t ()) @@
get_t_function_opt t
let get_t_function_exn t = match get_t_function_opt t with
| Some x -> x
| None -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__))
let get_t_function_full (t:type_expression) : (type_expression * type_expression) result =
let%bind _ = get_t_function t in
@ -190,9 +198,17 @@ let get_t_function_full (t:type_expression) : (type_expression * type_expression
let input = List.map (fun (l,t) -> (l,{field_type = t ; michelson_annotation = None ; field_decl_pos = 0})) input in
ok @@ (t_record (LMap.of_list input) (),output)
let get_t_sum (t:type_expression) : ctor_content constructor_map result = match t.type_content with
| T_sum m -> ok m
| _ -> fail @@ Errors.not_a_x_type "sum" t ()
let get_t_sum_opt (t:type_expression) : ctor_content constructor_map option = match t.type_content with
| T_sum m -> Some m
| _ -> None
let get_t_sum t = match get_t_sum_opt t with
| Some m -> ok m
| None -> fail @@ Errors.not_a_x_type "sum" t ()
let get_t_sum_exn t = match get_t_sum_opt t with
| Some m -> m
| None -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__))
let get_t_record (t:type_expression) : field_content label_map result = match t.type_content with
| T_record m -> ok m

View File

@ -63,8 +63,12 @@ val get_t_key_hash : type_expression -> unit result
val get_t_tuple : type_expression -> type_expression list result
val get_t_pair : type_expression -> ( type_expression * type_expression ) result
val get_t_function : type_expression -> ( type_expression * type_expression ) result
val get_t_function_opt : type_expression -> ( type_expression * type_expression ) option
val get_t_function_exn : type_expression -> ( type_expression * type_expression )
val get_t_function_full : type_expression -> ( type_expression * type_expression ) result
val get_t_sum : type_expression -> ctor_content constructor_map result
val get_t_sum_opt : type_expression -> ctor_content constructor_map option
val get_t_sum_exn : type_expression -> ctor_content constructor_map
val get_t_record : type_expression -> field_content label_map result
val get_t_map : type_expression -> ( type_expression * type_expression ) result
val get_t_big_map : type_expression -> ( type_expression * type_expression ) result

View File

@ -19,7 +19,7 @@ module Environment : sig
val to_list : t -> element list
val get_names : t -> expression_variable list
val remove : int -> t -> t
val select : ?rev:bool -> ?keep:bool -> expression_variable list -> t -> t
(* val select : ?rev:bool -> ?keep:bool -> expression_variable list -> t -> t *)
(*
val fold : ('a -> element -> 'a ) -> 'a -> t -> 'a
val filter : ( element -> bool ) -> t -> t