From 98d49959b7f6cc2ac5c1f6a9fcb76fa58138d087 Mon Sep 17 00:00:00 2001 From: galfour Date: Sat, 9 May 2020 13:21:02 +0200 Subject: [PATCH] add a pass to recompute environments --- src/passes/9-self_ast_typed/dune | 1 + .../9-self_ast_typed/recompute_environment.ml | 168 ++++++++++++++++++ src/passes/9-self_ast_typed/self_ast_typed.ml | 6 +- src/stages/4-ast_typed/combinators.ml | 28 ++- src/stages/4-ast_typed/combinators.mli | 4 + src/stages/5-mini_c/environment.mli | 2 +- 6 files changed, 200 insertions(+), 9 deletions(-) create mode 100644 src/passes/9-self_ast_typed/recompute_environment.ml diff --git a/src/passes/9-self_ast_typed/dune b/src/passes/9-self_ast_typed/dune index 0fc22a1d3..3f00581ac 100644 --- a/src/passes/9-self_ast_typed/dune +++ b/src/passes/9-self_ast_typed/dune @@ -4,6 +4,7 @@ (libraries simple-utils ast_typed + environment ) (preprocess (pps ppx_let bisect_ppx --conditional) diff --git a/src/passes/9-self_ast_typed/recompute_environment.ml b/src/passes/9-self_ast_typed/recompute_environment.ml new file mode 100644 index 000000000..417bbd059 --- /dev/null +++ b/src/passes/9-self_ast_typed/recompute_environment.ml @@ -0,0 +1,168 @@ +open Ast_typed + +(* + During the modifications of the passes on `Ast_typed`, the binding + environments are not kept in sync. To palliate this, this module + recomputes them from scratch. +*) + +(* + This module is very coupled to `typer.ml`. Given environments are + not used until the next pass, it makes sense to split this into + its own separate pass. This pass would go from `Ast_typed` without + environments to `Ast_typed` with embedded environments. +*) + +(* + BAD! + This representation a quadratic amount of space. As environments are + linear in the size of the program, and there is a linear number of them. +*) + +let rec expression : environment -> expression -> expression = fun env expr -> + (* Standard helper functions to help with the fold *) + let return ?(env' = env) content = { + expr with + environment = env' ; + expression_content = content ; + } in + let return_id = return expr.expression_content in + let self ?(env' = env) x = expression env' x in + let self_list lst = List.map self lst in + let self_2 a b = self a , self b in + let self_lmap lm = LMap.map self lm in + let self_cases cs = cases env cs in + match expr.expression_content with + | E_lambda c -> ( + let (t_binder , _) = Combinators.get_t_function_exn expr.type_expression in + let env' = Environment.add_ez_binder c.binder t_binder env in + let result = self ~env' c.result in + return @@ E_lambda { c with result } + ) + | E_let_in c -> ( + let env' = Environment.add_ez_declaration c.let_binder c.rhs env in + let let_result = self ~env' c.let_result in + let rhs = self c.rhs in + return @@ E_let_in { c with rhs ; let_result } + ) + (* rec fun_name binder -> result *) + | E_recursive c -> ( + let env_fun_name = Environment.add_ez_binder c.fun_name c.fun_type env in + let (t_binder , _) = Combinators.get_t_function_exn c.fun_type in + let env_binder = Environment.add_ez_binder c.lambda.binder t_binder env_fun_name in + let result = self ~env':env_binder c.lambda.result in + let lambda = { c.lambda with result } in + return @@ E_recursive { c with lambda } + ) + (* All the following cases are administrative *) + | E_literal _ -> return_id + | E_variable _ -> return_id + | E_constant c -> ( + let arguments = self_list c.arguments in + return @@ E_constant { c with arguments } + ) + | E_application c -> ( + let (lamb , args) = self_2 c.lamb c.args in + return @@ E_application { lamb ; args } + ) + | E_constructor c -> ( + let element = self c.element in + return @@ E_constructor { c with element } + ) + | E_record c -> ( + let c' = self_lmap c in + return @@ E_record c' + ) + | E_record_accessor c -> ( + let record = self c.record in + return @@ E_record_accessor { c with record } + ) + | E_record_update c -> ( + let (record , update) = self_2 c.record c.update in + return @@ E_record_update { c with record ; update } + ) + | E_matching c -> ( + let matchee = self c.matchee in + let cases = self_cases c.cases in + return @@ E_matching { matchee ; cases } + ) + +and cases : environment -> matching_expr -> matching_expr = fun env cs -> + let return x = x in + let self ?(env' = env) x = expression env' x in + match cs with + | Match_list c -> ( + let match_nil = self c.match_nil in + let match_cons = + let mc = c.match_cons in + let env_hd = Environment.add_ez_binder mc.hd mc.tv env in + let env_tl = Environment.add_ez_binder mc.tl (t_list mc.tv ()) env_hd in + let body = self ~env':env_tl mc.body in + { mc with body } + in + return @@ Match_list { match_nil ; match_cons } + ) + | Match_option c -> ( + let match_none = self c.match_none in + let match_some = + let ms = c.match_some in + let env' = Environment.add_ez_binder ms.opt ms.tv env in + let body = self ~env' ms.body in + { ms with body } + in + return @@ Match_option { match_none ; match_some } + ) + | Match_tuple c -> ( + let var_tvs = + try ( + List.combine c.vars c.tvs + ) with _ -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__)) + in + let env' = + let aux prev (var , tv) = + Environment.add_ez_binder var tv prev + in + List.fold_left aux env var_tvs + in + let body = self ~env' c.body in + return @@ Match_tuple { c with body } + ) + | Match_variant c -> ( + let variant_type = Combinators.get_t_sum_exn c.tv in + let cases = + let aux (c : matching_content_case) = + let case = + try ( + CMap.find c.constructor variant_type + ) with _ -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__)) + in + let env' = Environment.add_ez_binder c.pattern case.ctor_type env in + let body = self ~env' c.body in + { c with body } + in + List.map aux c.cases + in + return @@ Match_variant { c with cases } + ) + +let program : environment -> program -> program = fun init_env prog -> + (* + BAD + We take the old type environment and add it to the current value environment + because type declarations are removed in the typer. They should be added back. + *) + let merge old_env re_env = { + expression_environment = re_env.expression_environment ; + type_environment = old_env.type_environment ; + } in + let aux (pre_env , rev_decls) decl_wrapped = + let (Declaration_constant c) = Location.unwrap decl_wrapped in + let expr = expression pre_env c.expr in + let post_env = Environment.add_ez_declaration c.binder c.expr pre_env in + let post_env' = merge c.post_env post_env in + let wrap_content = Declaration_constant { c with expr ; post_env = post_env' } in + let decl_wrapped' = { decl_wrapped with wrap_content } in + (post_env , decl_wrapped' :: rev_decls) + in + let (_last_env , rev_decls) = List.fold_left aux (init_env , []) prog in + List.rev rev_decls diff --git a/src/passes/9-self_ast_typed/self_ast_typed.ml b/src/passes/9-self_ast_typed/self_ast_typed.ml index fc9d27a5c..77b50ce9c 100644 --- a/src/passes/9-self_ast_typed/self_ast_typed.ml +++ b/src/passes/9-self_ast_typed/self_ast_typed.ml @@ -10,9 +10,11 @@ let contract_passes = [ No_nested_big_map.self_typing ; ] -let all_program = +let all_program program = let all_p = List.map Helpers.map_program all_passes in - bind_chain all_p + let%bind program' = bind_chain all_p program in + let program'' = Recompute_environment.program Environment.default program' in + ok program'' let all_expression = let all_p = List.map Helpers.map_expression all_passes in diff --git a/src/stages/4-ast_typed/combinators.ml b/src/stages/4-ast_typed/combinators.ml index eca86d173..e7959cec7 100644 --- a/src/stages/4-ast_typed/combinators.ml +++ b/src/stages/4-ast_typed/combinators.ml @@ -174,9 +174,17 @@ let get_t_pair (t:type_expression) : (type_expression * type_expression) result ok List.(nth lst 0 , nth lst 1) | _ -> fail @@ Errors.not_a_x_type "pair (tuple with two elements)" t () -let get_t_function (t:type_expression) : (type_expression * type_expression) result = match t.type_content with - | T_arrow {type1;type2} -> ok (type1,type2) - | _ -> simple_fail "not a function" +let get_t_function_opt (t:type_expression) : (type_expression * type_expression) option = match t.type_content with + | T_arrow {type1;type2} -> Some (type1,type2) + | _ -> None + +let get_t_function t = + trace_option (Errors.not_a_x_type "function" t ()) @@ + get_t_function_opt t + +let get_t_function_exn t = match get_t_function_opt t with + | Some x -> x + | None -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__)) let get_t_function_full (t:type_expression) : (type_expression * type_expression) result = let%bind _ = get_t_function t in @@ -190,9 +198,17 @@ let get_t_function_full (t:type_expression) : (type_expression * type_expression let input = List.map (fun (l,t) -> (l,{field_type = t ; michelson_annotation = None ; field_decl_pos = 0})) input in ok @@ (t_record (LMap.of_list input) (),output) -let get_t_sum (t:type_expression) : ctor_content constructor_map result = match t.type_content with - | T_sum m -> ok m - | _ -> fail @@ Errors.not_a_x_type "sum" t () +let get_t_sum_opt (t:type_expression) : ctor_content constructor_map option = match t.type_content with + | T_sum m -> Some m + | _ -> None + +let get_t_sum t = match get_t_sum_opt t with + | Some m -> ok m + | None -> fail @@ Errors.not_a_x_type "sum" t () + +let get_t_sum_exn t = match get_t_sum_opt t with + | Some m -> m + | None -> raise (Failure ("Internal error: broken invariant at " ^ __LOC__)) let get_t_record (t:type_expression) : field_content label_map result = match t.type_content with | T_record m -> ok m diff --git a/src/stages/4-ast_typed/combinators.mli b/src/stages/4-ast_typed/combinators.mli index e3ce6a156..7568455d5 100644 --- a/src/stages/4-ast_typed/combinators.mli +++ b/src/stages/4-ast_typed/combinators.mli @@ -63,8 +63,12 @@ val get_t_key_hash : type_expression -> unit result val get_t_tuple : type_expression -> type_expression list result val get_t_pair : type_expression -> ( type_expression * type_expression ) result val get_t_function : type_expression -> ( type_expression * type_expression ) result +val get_t_function_opt : type_expression -> ( type_expression * type_expression ) option +val get_t_function_exn : type_expression -> ( type_expression * type_expression ) val get_t_function_full : type_expression -> ( type_expression * type_expression ) result val get_t_sum : type_expression -> ctor_content constructor_map result +val get_t_sum_opt : type_expression -> ctor_content constructor_map option +val get_t_sum_exn : type_expression -> ctor_content constructor_map val get_t_record : type_expression -> field_content label_map result val get_t_map : type_expression -> ( type_expression * type_expression ) result val get_t_big_map : type_expression -> ( type_expression * type_expression ) result diff --git a/src/stages/5-mini_c/environment.mli b/src/stages/5-mini_c/environment.mli index 231925b97..51191df32 100644 --- a/src/stages/5-mini_c/environment.mli +++ b/src/stages/5-mini_c/environment.mli @@ -19,7 +19,7 @@ module Environment : sig val to_list : t -> element list val get_names : t -> expression_variable list val remove : int -> t -> t - val select : ?rev:bool -> ?keep:bool -> expression_variable list -> t -> t + (* val select : ?rev:bool -> ?keep:bool -> expression_variable list -> t -> t *) (* val fold : ('a -> element -> 'a ) -> 'a -> t -> 'a val filter : ( element -> bool ) -> t -> t