typer: do multiple substitutions at once (pass a sort of map from free variables to their substitution)

This commit is contained in:
Suzanne Dupéron 2019-12-06 17:48:57 +01:00 committed by Suzanne Dupéron
parent 688a636251
commit 93d16b4b6a
4 changed files with 166 additions and 144 deletions

View File

@ -352,7 +352,7 @@ end
module TypeVariable = module TypeVariable =
struct struct
type t = Core.type_variable type t = Core.type_variable
let compare a b= Var.compare a b let compare a b = Var.compare a b
let to_string = (fun s -> Format.asprintf "%a" Var.pp s) let to_string = (fun s -> Format.asprintf "%a" Var.pp s)
end end

View File

@ -961,15 +961,23 @@ let type_program_returns_state (p:I.program) : (environment * Solver.state * O.p
let type_program (p : I.program) : (O.program * Solver.state) result = let type_program (p : I.program) : (O.program * Solver.state) result =
let%bind (env, state, program) = type_program_returns_state p in let%bind (env, state, program) = type_program_returns_state p in
let subst_all = let subst_all =
let aliases = state.structured_dbs.aliases in
let assignments = state.structured_dbs.assignments in let assignments = state.structured_dbs.assignments in
let aux (v : I.type_variable) (expr : Solver.c_constructor_simpl) (p:O.program result) = let substs : variable: I.type_variable -> _ = fun ~variable ->
let%bind p = p in to_option @@
let Solver.{ tv ; c_tag ; tv_list } = expr in let%bind root =
trace_option (simple_error (Format.asprintf "can't find alias root of variable %a" Var.pp variable)) @@
(* TODO: after upgrading UnionFind, this will be an option, not an exception. *)
try Some (Solver.UF.repr variable aliases) with Not_found -> None in
let%bind assignment =
trace_option (simple_error (Format.asprintf "can't find assignment for root %a" Var.pp root)) @@
(Solver.TypeVariableMap.find_opt root assignments) in
let Solver.{ tv ; c_tag ; tv_list } = assignment in
let () = ignore tv (* I think there is an issue where the tv is stored twice (as a key and in the element itself) *) in let () = ignore tv (* I think there is an issue where the tv is stored twice (as a key and in the element itself) *) in
let%bind (expr : O.type_value') = Typesystem.Core.type_expression'_of_simple_c_constant (c_tag , (List.map (fun s -> O.{ type_value' = T_variable s ; simplified = None }) tv_list)) in let%bind (expr : O.type_value') = Typesystem.Core.type_expression'_of_simple_c_constant (c_tag , (List.map (fun s -> O.{ type_value' = T_variable s ; simplified = None }) tv_list)) in
Typesystem.Misc.Substitution.Pattern.program ~p ~v ~expr in ok @@ expr
(* let p = TSMap.bind_fold_Map aux program assignments in *) (* TODO: Module magic: this does not work *) in
let p = Solver.TypeVariableMap.fold aux assignments (ok program) in let p = Typesystem.Misc.Substitution.Pattern.s_program ~substs program in
p in p in
let%bind program = subst_all in let%bind program = subst_all in
let () = ignore env in (* TODO: shouldn't we use the `env` somewhere? *) let () = ignore env in (* TODO: shouldn't we use the `env` somewhere? *)

View File

@ -34,6 +34,12 @@ and annotated_expression = {
location : Location.t ; location : Location.t ;
} }
(* This seems to be used only for top-level declarations, and
represents the name of the top-level binding, and the expression
assigned to it. -- Suzanne.
TODO: if this is correct, then we should inline this in
"declaration" or at least move it close to it. *)
and named_expression = { and named_expression = {
name: expression_variable ; name: expression_variable ;
annotated_expression: ae ; annotated_expression: ae ;

View File

@ -9,113 +9,118 @@ module Substitution = struct
module T = Ast_typed module T = Ast_typed
(* module TSMap = Trace.TMap(String) *) (* module TSMap = Trace.TMap(String) *)
type 'a w = 'a -> 'a result type substs = variable:type_variable -> T.type_value' option (* this string is a type_name or type_variable I think *)
let mk_substs ~v ~expr = (v , expr)
type 'a w = substs:substs -> 'a -> 'a result
let rec rec_yes = true let rec rec_yes = true
and s_environment_element_definition ~v ~expr = function and s_environment_element_definition ~substs = function
| T.ED_binder -> ok @@ T.ED_binder | T.ED_binder -> ok @@ T.ED_binder
| T.ED_declaration (val_, free_variables) -> | T.ED_declaration (val_, free_variables) ->
let%bind val_ = s_annotated_expression ~v ~expr val_ in let%bind val_ = s_annotated_expression ~substs val_ in
let%bind free_variables = bind_map_list (s_variable ~v ~expr) free_variables in let%bind free_variables = bind_map_list (s_variable ~substs) free_variables in
ok @@ T.ED_declaration (val_, free_variables) ok @@ T.ED_declaration (val_, free_variables)
and s_environment ~v ~expr : T.environment w = fun env -> and s_environment : T.environment w = fun ~substs env ->
bind_map_list (fun (variable, T.{ type_value; source_environment; definition }) -> bind_map_list (fun (variable, T.{ type_value; source_environment; definition }) ->
let%bind variable = s_variable ~v ~expr variable in let%bind variable = s_variable ~substs variable in
let%bind type_value = s_type_value ~v ~expr type_value in let%bind type_value = s_type_value ~substs type_value in
let%bind source_environment = s_full_environment ~v ~expr source_environment in let%bind source_environment = s_full_environment ~substs source_environment in
let%bind definition = s_environment_element_definition ~v ~expr definition in let%bind definition = s_environment_element_definition ~substs definition in
ok @@ (variable, T.{ type_value; source_environment; definition })) env ok @@ (variable, T.{ type_value; source_environment; definition })) env
and s_type_environment ~v ~expr : T.type_environment w = fun tenv -> and s_type_environment : T.type_environment w = fun ~substs tenv ->
bind_map_list (fun (type_variable , type_value) -> bind_map_list (fun (type_variable , type_value) ->
let%bind type_variable = s_type_variable ~v ~expr type_variable in let%bind type_variable = s_type_variable ~substs type_variable in
let%bind type_value = s_type_value ~v ~expr type_value in let%bind type_value = s_type_value ~substs type_value in
ok @@ (type_variable , type_value)) tenv ok @@ (type_variable , type_value)) tenv
and s_small_environment ~v ~expr : T.small_environment w = fun (environment, type_environment) -> and s_small_environment : T.small_environment w = fun ~substs (environment, type_environment) ->
let%bind environment = s_environment ~v ~expr environment in let%bind environment = s_environment ~substs environment in
let%bind type_environment = s_type_environment ~v ~expr type_environment in let%bind type_environment = s_type_environment ~substs type_environment in
ok @@ (environment, type_environment) ok @@ (environment, type_environment)
and s_full_environment ~v ~expr : T.full_environment w = fun (a , b) -> and s_full_environment : T.full_environment w = fun ~substs (a , b) ->
let%bind a = s_small_environment ~v ~expr a in let%bind a = s_small_environment ~substs a in
let%bind b = bind_map_list (s_small_environment ~v ~expr) b in let%bind b = bind_map_list (s_small_environment ~substs) b in
ok (a , b) ok (a , b)
and s_variable ~v ~expr : T.expression_variable w = fun var -> and s_variable : T.expression_variable w = fun ~substs var ->
let () = ignore (v, expr) in let () = ignore @@ substs in
ok var ok var
and s_type_variable ~v ~expr : T.type_variable w = fun tvar -> and s_type_variable : T.type_variable w = fun ~substs tvar ->
let _TODO = ignore (v, expr) in let _TODO = ignore @@ substs in
Printf.printf "TODO: subst: unimplemented case s_type_variable"; Printf.printf "TODO: subst: unimplemented case s_type_variable";
ok @@ tvar ok @@ tvar
(* if String.equal tvar v then (* if String.equal tvar v then
* expr * expr
* else * else
* ok tvar *) * ok tvar *)
and s_label ~v ~expr : T.label w = fun l -> and s_label : T.label w = fun ~substs l ->
let () = ignore (v, expr) in let () = ignore @@ substs in
ok l ok l
and s_build_in ~v ~expr : T.constant w = fun b -> and s_build_in : T.constant w = fun ~substs b ->
let () = ignore (v, expr) in let () = ignore @@ substs in
ok b ok b
and s_constructor ~v ~expr : T.constructor w = fun c -> and s_constructor : T.constructor w = fun ~substs c ->
let () = ignore (v, expr) in let () = ignore @@ substs in
ok c ok c
and s_type_name_constant ~v ~expr : T.type_constant w = fun type_name -> and s_type_name_constant : T.type_constant w = fun ~substs type_name ->
(* TODO: we don't need to subst anything, right? *) (* TODO: we don't need to subst anything, right? *)
let () = ignore (v , expr) in let () = ignore @@ substs in
ok @@ type_name ok @@ type_name
and s_type_value' ~v ~expr : T.type_value' w = function and s_type_value' : T.type_value' w = fun ~substs -> function
| T.T_tuple type_value_list -> | T.T_tuple type_value_list ->
let%bind type_value_list = bind_map_list (s_type_value ~v ~expr) type_value_list in let%bind type_value_list = bind_map_list (s_type_value ~substs) type_value_list in
ok @@ T.T_tuple type_value_list ok @@ T.T_tuple type_value_list
| T.T_sum _ -> failwith "TODO: T_sum" | T.T_sum _ -> failwith "TODO: T_sum"
| T.T_record _ -> failwith "TODO: T_record" | T.T_record _ -> failwith "TODO: T_record"
| T.T_constant (type_name) -> | T.T_constant type_name ->
let%bind type_name = s_type_name_constant ~v ~expr type_name in let%bind type_name = s_type_name_constant ~substs type_name in
ok @@ T.T_constant (type_name) ok @@ T.T_constant (type_name)
| T.T_variable variable -> | T.T_variable variable ->
if Var.equal variable v begin
then ok @@ expr match substs ~variable with
else ok @@ T.T_variable variable | Some expr -> s_type_value' ~substs expr (* TODO: is it the right thing to recursively examine this? We mustn't go into an infinite loop. *)
| T.T_operator (type_name_and_args) -> | None -> ok @@ T.T_variable variable
end
| T.T_operator type_name_and_args ->
let bind_map_type_operator = Stage_common.Misc.bind_map_type_operator in (* TODO: write T.Misc.bind_map_type_operator, but it doesn't work *) let bind_map_type_operator = Stage_common.Misc.bind_map_type_operator in (* TODO: write T.Misc.bind_map_type_operator, but it doesn't work *)
let%bind type_name_and_args = bind_map_type_operator (s_type_value ~v ~expr) type_name_and_args in let%bind type_name_and_args = bind_map_type_operator (s_type_value ~substs) type_name_and_args in
ok @@ T.T_operator type_name_and_args ok @@ T.T_operator type_name_and_args
| T.T_arrow _ -> | T.T_arrow _ ->
let _TODO = (v, expr) in let _TODO = substs in
failwith "TODO: T_function" failwith "TODO: T_function"
and s_type_expression' ~v ~expr : _ Ast_simplified.type_expression' w = fun type_expression' -> and s_type_expression' : _ Ast_simplified.type_expression' w = fun ~substs -> function
match type_expression' with | Ast_simplified.T_tuple _ -> failwith "TODO: subst: unimplemented case s_type_expression tuple"
| Ast_simplified.T_tuple _ -> failwith "TODO: subst: unimplemented case s_type_expression tuple" | Ast_simplified.T_sum _ -> failwith "TODO: subst: unimplemented case s_type_expression sum"
| Ast_simplified.T_sum _ -> failwith "TODO: subst: unimplemented case s_type_expression sum" | Ast_simplified.T_record _ -> failwith "TODO: subst: unimplemented case s_type_expression record"
| Ast_simplified.T_record _ -> failwith "TODO: subst: unimplemented case s_type_expression record" | Ast_simplified.T_arrow (_, _) -> failwith "TODO: subst: unimplemented case s_type_expression arrow"
| Ast_simplified.T_arrow (_, _) -> failwith "TODO: subst: unimplemented case s_type_expression arrow" | Ast_simplified.T_variable _ -> failwith "TODO: subst: unimplemented case s_type_expression variable"
| Ast_simplified.T_variable _ -> failwith "TODO: subst: unimplemented case s_type_expression variable" | Ast_simplified.T_operator op ->
| Ast_simplified.T_operator op -> let%bind op =
let%bind op = Stage_common.Misc.bind_map_type_operator (* TODO: write Ast_simplified.Misc.type_operator_name *)
Stage_common.Misc.bind_map_type_operator (* TODO: write Ast_simplified.Misc.type_operator_name *) (s_type_expression ~substs)
(s_type_expression ~v ~expr) op in
op in (* TODO: when we have generalized operators, we might need to subst the operator name itself? *)
ok @@ Ast_simplified.T_operator op ok @@ Ast_simplified.T_operator op
| Ast_simplified.T_constant constant -> | Ast_simplified.T_constant constant ->
ok @@ Ast_simplified.T_constant constant ok @@ Ast_simplified.T_constant constant
and s_type_expression ~v ~expr : Ast_simplified.type_expression w = fun {type_expression'} -> and s_type_expression : Ast_simplified.type_expression w = fun ~substs {type_expression'} ->
let%bind type_expression' = s_type_expression' ~v ~expr type_expression' in let%bind type_expression' = s_type_expression' ~substs type_expression' in
ok @@ Ast_simplified.{type_expression'} ok @@ Ast_simplified.{type_expression'}
and s_type_value ~v ~expr : T.type_value w = fun { type_value'; simplified } -> and s_type_value : T.type_value w = fun ~substs { type_value'; simplified } ->
let%bind type_value' = s_type_value' ~v ~expr type_value' in let%bind type_value' = s_type_value' ~substs type_value' in
let%bind simplified = bind_map_option (s_type_expression ~v ~expr) simplified in let%bind simplified = bind_map_option (s_type_expression ~substs) simplified in
ok @@ T.{ type_value'; simplified } ok @@ T.{ type_value'; simplified }
and s_literal ~v ~expr : T.literal w = function and s_literal : T.literal w = fun ~substs -> function
| T.Literal_unit -> | T.Literal_unit ->
let () = ignore (v, expr) in let () = ignore @@ substs in
ok @@ T.Literal_unit ok @@ T.Literal_unit
| (T.Literal_bool _ as x) | (T.Literal_bool _ as x)
| (T.Literal_int _ as x) | (T.Literal_int _ as x)
@ -131,142 +136,143 @@ module Substitution = struct
| (T.Literal_chain_id _ as x) | (T.Literal_chain_id _ as x)
| (T.Literal_operation _ as x) -> | (T.Literal_operation _ as x) ->
ok @@ x ok @@ x
and s_matching_expr ~v ~expr : T.matching_expr w = fun _ -> and s_matching_expr : T.matching_expr w = fun ~substs _ ->
let _TODO = v, expr in let _TODO = substs in
failwith "TODO: subst: unimplemented case s_matching" failwith "TODO: subst: unimplemented case s_matching"
and s_named_type_value ~v ~expr : T.named_type_value w = fun _ -> and s_named_type_value : T.named_type_value w = fun ~substs _ ->
let _TODO = v, expr in let _TODO = substs in
failwith "TODO: subst: unimplemented case s_named_type_value" failwith "TODO: subst: unimplemented case s_named_type_value"
and s_access_path ~v ~expr : T.access_path w = fun _ -> and s_access_path : T.access_path w = fun ~substs _ ->
let _TODO = v, expr in let _TODO = substs in
failwith "TODO: subst: unimplemented case s_access_path" failwith "TODO: subst: unimplemented case s_access_path"
and s_expression ~v ~expr : T.expression w = function and s_expression : T.expression w = fun ~(substs : substs) -> function
| T.E_literal x -> | T.E_literal x ->
let%bind x = s_literal ~v ~expr x in let%bind x = s_literal ~substs x in
ok @@ T.E_literal x ok @@ T.E_literal x
| T.E_constant (var, vals) -> | T.E_constant (var, vals) ->
let%bind var = s_build_in ~v ~expr var in let%bind var = s_build_in ~substs var in
let%bind vals = bind_map_list (s_annotated_expression ~v ~expr) vals in let%bind vals = bind_map_list (s_annotated_expression ~substs) vals in
ok @@ T.E_constant (var, vals) ok @@ T.E_constant (var, vals)
| T.E_variable tv -> | T.E_variable tv ->
let%bind tv = s_variable ~v ~expr tv in let%bind tv = s_variable ~substs tv in
ok @@ T.E_variable tv ok @@ T.E_variable tv
| T.E_application (val1 , val2) -> | T.E_application (val1 , val2) ->
let%bind val1 = s_annotated_expression ~v ~expr val1 in let%bind val1 = s_annotated_expression ~substs val1 in
let%bind val2 = s_annotated_expression ~v ~expr val2 in let%bind val2 = s_annotated_expression ~substs val2 in
ok @@ T.E_application (val1 , val2) ok @@ T.E_application (val1 , val2)
| T.E_lambda { binder; body } -> | T.E_lambda { binder; body } ->
let%bind binder = s_variable ~v ~expr binder in let%bind binder = s_variable ~substs binder in
let%bind body = s_annotated_expression ~v ~expr body in let%bind body = s_annotated_expression ~substs body in
ok @@ T.E_lambda { binder; body } ok @@ T.E_lambda { binder; body }
| T.E_let_in { binder; rhs; result; inline } -> | T.E_let_in { binder; rhs; result; inline } ->
let%bind binder = s_variable ~v ~expr binder in let%bind binder = s_variable ~substs binder in
let%bind rhs = s_annotated_expression ~v ~expr rhs in let%bind rhs = s_annotated_expression ~substs rhs in
let%bind result = s_annotated_expression ~v ~expr result in let%bind result = s_annotated_expression ~substs result in
ok @@ T.E_let_in { binder; rhs; result; inline } ok @@ T.E_let_in { binder; rhs; result; inline }
| T.E_tuple vals -> | T.E_tuple vals ->
let%bind vals = bind_map_list (s_annotated_expression ~v ~expr) vals in let%bind vals = bind_map_list (s_annotated_expression ~substs) vals in
ok @@ T.E_tuple vals ok @@ T.E_tuple vals
| T.E_tuple_accessor (val_, i) -> | T.E_tuple_accessor (val_, i) ->
let%bind val_ = s_annotated_expression ~v ~expr val_ in let%bind val_ = s_annotated_expression ~substs val_ in
let i = i in let i = i in
ok @@ T.E_tuple_accessor (val_, i) ok @@ T.E_tuple_accessor (val_, i)
| T.E_constructor (tvar, val_) -> | T.E_constructor (tvar, val_) ->
let%bind tvar = s_constructor ~v ~expr tvar in let%bind tvar = s_constructor ~substs tvar in
let%bind val_ = s_annotated_expression ~v ~expr val_ in let%bind val_ = s_annotated_expression ~substs val_ in
ok @@ T.E_constructor (tvar, val_) ok @@ T.E_constructor (tvar, val_)
| T.E_record aemap -> | T.E_record aemap ->
let _TODO = aemap in let _TODO = aemap in
failwith "TODO: subst in record" failwith "TODO: subst in record"
(* let%bind aemap = TSMap.bind_map_Map (fun ~k:key ~v:val_ -> (* let%bind aemap = TSMap.bind_map_Map (fun ~k:key ~v:val_ ->
* let key = s_type_variable ~v ~expr key in * let key = s_type_variable ~substs key in
* let val_ = s_annotated_expression ~v ~expr val_ in * let val_ = s_annotated_expression ~substs val_ in
* ok @@ (key , val_)) aemap in * ok @@ (key , val_)) aemap in
* ok @@ T.E_record aemap *) * ok @@ T.E_record aemap *)
| T.E_record_accessor (val_, l) -> | T.E_record_accessor (val_, l) ->
let%bind val_ = s_annotated_expression ~v ~expr val_ in let%bind val_ = s_annotated_expression ~substs val_ in
let%bind l = s_label ~v ~expr l in let l = l in (* Nothing to substitute, this is a label, not a type *)
ok @@ T.E_record_accessor (val_, l) ok @@ T.E_record_accessor (val_, l)
| T.E_record_update (r, ups) -> | T.E_record_update (r, ups) ->
let%bind r = s_annotated_expression ~v ~expr r in let%bind r = s_annotated_expression ~substs r in
let%bind ups = bind_map_list (fun (l,e) -> let%bind e = s_annotated_expression ~v ~expr e in ok (l,e)) ups in let%bind ups = bind_map_list (fun (l,e) -> let%bind e = s_annotated_expression ~substs e in ok (l,e)) ups in
ok @@ T.E_record_update (r,ups) ok @@ T.E_record_update (r,ups)
| T.E_map val_val_list -> | T.E_map val_val_list ->
let%bind val_val_list = bind_map_list (fun (val1 , val2) -> let%bind val_val_list = bind_map_list (fun (val1 , val2) ->
let%bind val1 = s_annotated_expression ~v ~expr val1 in let%bind val1 = s_annotated_expression ~substs val1 in
let%bind val2 = s_annotated_expression ~v ~expr val2 in let%bind val2 = s_annotated_expression ~substs val2 in
ok @@ (val1 , val2) ok @@ (val1 , val2)
) val_val_list in ) val_val_list in
ok @@ T.E_map val_val_list ok @@ T.E_map val_val_list
| T.E_big_map val_val_list -> | T.E_big_map val_val_list ->
let%bind val_val_list = bind_map_list (fun (val1 , val2) -> let%bind val_val_list = bind_map_list (fun (val1 , val2) ->
let%bind val1 = s_annotated_expression ~v ~expr val1 in let%bind val1 = s_annotated_expression ~substs val1 in
let%bind val2 = s_annotated_expression ~v ~expr val2 in let%bind val2 = s_annotated_expression ~substs val2 in
ok @@ (val1 , val2) ok @@ (val1 , val2)
) val_val_list in ) val_val_list in
ok @@ T.E_big_map val_val_list ok @@ T.E_big_map val_val_list
| T.E_list vals -> | T.E_list vals ->
let%bind vals = bind_map_list (s_annotated_expression ~v ~expr) vals in let%bind vals = bind_map_list (s_annotated_expression ~substs) vals in
ok @@ T.E_list vals ok @@ T.E_list vals
| T.E_set vals -> | T.E_set vals ->
let%bind vals = bind_map_list (s_annotated_expression ~v ~expr) vals in let%bind vals = bind_map_list (s_annotated_expression ~substs) vals in
ok @@ T.E_set vals ok @@ T.E_set vals
| T.E_look_up (val1, val2) -> | T.E_look_up (val1, val2) ->
let%bind val1 = s_annotated_expression ~v ~expr val1 in let%bind val1 = s_annotated_expression ~substs val1 in
let%bind val2 = s_annotated_expression ~v ~expr val2 in let%bind val2 = s_annotated_expression ~substs val2 in
ok @@ T.E_look_up (val1 , val2) ok @@ T.E_look_up (val1 , val2)
| T.E_matching (val_ , matching_expr) -> | T.E_matching (val_ , matching_expr) ->
let%bind val_ = s_annotated_expression ~v ~expr val_ in let%bind val_ = s_annotated_expression ~substs val_ in
let%bind matching = s_matching_expr ~v ~expr matching_expr in let%bind matching = s_matching_expr ~substs matching_expr in
ok @@ T.E_matching (val_ , matching) ok @@ T.E_matching (val_ , matching)
| T.E_sequence (val1, val2) -> | T.E_sequence (val1, val2) ->
let%bind val1 = s_annotated_expression ~v ~expr val1 in let%bind val1 = s_annotated_expression ~substs val1 in
let%bind val2 = s_annotated_expression ~v ~expr val2 in let%bind val2 = s_annotated_expression ~substs val2 in
ok @@ T.E_sequence (val1 , val2) ok @@ T.E_sequence (val1 , val2)
| T.E_loop (val1, val2) -> | T.E_loop (val1, val2) ->
let%bind val1 = s_annotated_expression ~v ~expr val1 in let%bind val1 = s_annotated_expression ~substs val1 in
let%bind val2 = s_annotated_expression ~v ~expr val2 in let%bind val2 = s_annotated_expression ~substs val2 in
ok @@ T.E_loop (val1 , val2) ok @@ T.E_loop (val1 , val2)
| T.E_assign (named_tval, access_path, val_) -> | T.E_assign (named_tval, access_path, val_) ->
let%bind named_tval = s_named_type_value ~v ~expr named_tval in let%bind named_tval = s_named_type_value ~substs named_tval in
let%bind access_path = s_access_path ~v ~expr access_path in let%bind access_path = s_access_path ~substs access_path in
let%bind val_ = s_annotated_expression ~v ~expr val_ in let%bind val_ = s_annotated_expression ~substs val_ in
ok @@ T.E_assign (named_tval, access_path, val_) ok @@ T.E_assign (named_tval, access_path, val_)
and s_annotated_expression ~v ~expr : T.annotated_expression w = fun { expression; type_annotation; environment; location } -> and s_annotated_expression : T.annotated_expression w = fun ~substs { expression; type_annotation; environment; location } ->
let%bind expression = s_expression ~v ~expr expression in let%bind expression = s_expression ~substs expression in
let%bind type_annotation = s_type_value ~v ~expr type_annotation in let%bind type_annotation = s_type_value ~substs type_annotation in
let%bind environment = s_full_environment ~v ~expr environment in let%bind environment = s_full_environment ~substs environment in
let location = location in let location = location in
ok T.{ expression; type_annotation; environment; location } ok T.{ expression; type_annotation; environment; location }
and s_named_expression ~v ~expr : T.named_expression w = fun { name; annotated_expression } -> and s_named_expression : T.named_expression w = fun ~substs { name; annotated_expression } ->
let%bind name = s_variable ~v ~expr name in let name = name in (* Nothing to substitute, this is a variable name *)
let%bind annotated_expression = s_annotated_expression ~v ~expr annotated_expression in let%bind annotated_expression = s_annotated_expression ~substs annotated_expression in
ok T.{ name; annotated_expression } ok T.{ name; annotated_expression }
and s_declaration ~v ~expr : T.declaration w = and s_declaration : T.declaration w = fun ~substs ->
function function
Ast_typed.Declaration_constant (e, i, (env1, env2)) -> Ast_typed.Declaration_constant (e, inline, (env1, env2)) ->
let%bind e = s_named_expression ~v ~expr e in let%bind e = s_named_expression ~substs e in
let%bind env1 = s_full_environment ~v ~expr env1 in let%bind env1 = s_full_environment ~substs env1 in
let%bind env2 = s_full_environment ~v ~expr env2 in let%bind env2 = s_full_environment ~substs env2 in
ok @@ Ast_typed.Declaration_constant (e, i, (env1, env2)) ok @@ Ast_typed.Declaration_constant (e, inline, (env1, env2))
and s_declaration_wrap ~v ~expr : T.declaration Location.wrap w = fun d -> and s_declaration_wrap : T.declaration Location.wrap w = fun ~substs d ->
Trace.bind_map_location (s_declaration ~v ~expr) d Trace.bind_map_location (s_declaration ~substs) d
(* Replace the type variable ~v with ~expr everywhere within the (* Replace the type variable ~v with ~expr everywhere within the
program ~p. TODO: issues with scoping/shadowing. *) program ~p. TODO: issues with scoping/shadowing. *)
and program ~(p : Ast_typed.program) ~(v:type_variable) ~expr : Ast_typed.program Trace.result = and s_program : Ast_typed.program w = fun ~substs p ->
Trace.bind_map_list (s_declaration_wrap ~v ~expr) p Trace.bind_map_list (s_declaration_wrap ~substs) p
(* (*
Computes `P[v := expr]`. Computes `P[v := expr]`.
*) *)
and type_value ~tv ~v ~expr = and type_value ~tv ~substs =
let self tv = type_value ~tv ~v ~expr in let self tv = type_value ~tv ~substs in
let (v, expr) = substs in
match tv with match tv with
| P_variable v' when v' = v -> expr | P_variable v' when v' = v -> expr
| P_variable _ -> tv | P_variable _ -> tv
@ -279,7 +285,7 @@ module Substitution = struct
P_apply ab' P_apply ab'
) )
| P_forall p -> ( | P_forall p -> (
let aux c = constraint_ ~c ~v ~expr in let aux c = constraint_ ~c ~substs in
let constraints = List.map aux p.constraints in let constraints = List.map aux p.constraints in
if (p.binder = v) then ( if (p.binder = v) then (
P_forall { p with constraints } P_forall { p with constraints }
@ -289,31 +295,33 @@ module Substitution = struct
) )
) )
and constraint_ ~c ~v ~expr = and constraint_ ~c ~substs =
match c with match c with
| C_equation ab -> ( | C_equation ab -> (
let ab' = pair_map (fun tv -> type_value ~tv ~v ~expr) ab in let ab' = pair_map (fun tv -> type_value ~tv ~substs) ab in
C_equation ab' C_equation ab'
) )
| C_typeclass (tvs , tc) -> ( | C_typeclass (tvs , tc) -> (
let tvs' = List.map (fun tv -> type_value ~tv ~v ~expr) tvs in let tvs' = List.map (fun tv -> type_value ~tv ~substs) tvs in
let tc' = typeclass ~tc ~v ~expr in let tc' = typeclass ~tc ~substs in
C_typeclass (tvs' , tc') C_typeclass (tvs' , tc')
) )
| C_access_label (tv , l , v') -> ( | C_access_label (tv , l , v') -> (
let tv' = type_value ~tv ~v ~expr in let tv' = type_value ~tv ~substs in
C_access_label (tv' , l , v') C_access_label (tv' , l , v')
) )
and typeclass ~tc ~v ~expr = and typeclass ~tc ~substs =
List.map (List.map (fun tv -> type_value ~tv ~v ~expr)) tc List.map (List.map (fun tv -> type_value ~tv ~substs)) tc
let program = s_program
(* Performs beta-reduction at the root of the type *) (* Performs beta-reduction at the root of the type *)
let eval_beta_root ~(tv : type_value) = let eval_beta_root ~(tv : type_value) =
match tv with match tv with
P_apply (P_forall { binder; constraints; body }, arg) -> P_apply (P_forall { binder; constraints; body }, arg) ->
let constraints = List.map (fun c -> constraint_ ~c ~v:binder ~expr:arg) constraints in let constraints = List.map (fun c -> constraint_ ~c ~substs:(mk_substs ~v:binder ~expr:arg)) constraints in
(type_value ~tv:body ~v:binder ~expr:arg , constraints) (type_value ~tv:body ~substs:(mk_substs ~v:binder ~expr:arg) , constraints)
| _ -> (tv , []) | _ -> (tv , [])
end end