diff --git a/src/main/compile/of_core.ml b/src/main/compile/of_core.ml index e6f0dfbba..c7992399e 100644 --- a/src/main/compile/of_core.ml +++ b/src/main/compile/of_core.ml @@ -4,7 +4,7 @@ type form = | Contract of string | Env -let compile (cform: form) (program : Ast_core.program) : (Ast_typed.program * Typer.Solver.state) result = +let compile (cform: form) (program : Ast_core.program) : (Ast_typed.program * Ast_typed.typer_state) result = let%bind (prog_typed , state) = Typer.type_program program in let () = Typer.Solver.discard_state state in let%bind applied = Self_ast_typed.all_program prog_typed in @@ -13,8 +13,8 @@ let compile (cform: form) (program : Ast_core.program) : (Ast_typed.program * Ty | Env -> ok applied in ok @@ (applied', state) -let compile_expression ?(env = Ast_typed.Environment.full_empty) ~(state : Typer.Solver.state) (e : Ast_core.expression) - : (Ast_typed.expression * Typer.Solver.state) result = +let compile_expression ?(env = Ast_typed.Environment.full_empty) ~(state : Ast_typed.typer_state) (e : Ast_core.expression) + : (Ast_typed.expression * Ast_typed.typer_state) result = let%bind (ae_typed,state) = Typer.type_expression_subst env state e in let () = Typer.Solver.discard_state state in let%bind ae_typed' = Self_ast_typed.all_expression ae_typed in diff --git a/src/passes/8-typer-new/PP.ml b/src/passes/8-typer-new/PP.ml index dbca3e194..db1512f19 100644 --- a/src/passes/8-typer-new/PP.ml +++ b/src/passes/8-typer-new/PP.ml @@ -1,35 +1,36 @@ -open Solver +open Ast_typed open Format +module UF = UnionFind.Poly2 let type_constraint : _ -> type_constraint_simpl -> unit = fun ppf -> function |SC_Constructor { tv; c_tag; tv_list=_ } -> let ct = match c_tag with - | Solver.Core.C_arrow -> "arrow" - | Solver.Core.C_option -> "option" - | Solver.Core.C_record -> failwith "record" - | Solver.Core.C_variant -> failwith "variant" - | Solver.Core.C_map -> "map" - | Solver.Core.C_big_map -> "big_map" - | Solver.Core.C_list -> "list" - | Solver.Core.C_set -> "set" - | Solver.Core.C_unit -> "unit" - | Solver.Core.C_string -> "string" - | Solver.Core.C_nat -> "nat" - | Solver.Core.C_mutez -> "mutez" - | Solver.Core.C_timestamp -> "timestamp" - | Solver.Core.C_int -> "int" - | Solver.Core.C_address -> "address" - | Solver.Core.C_bytes -> "bytes" - | Solver.Core.C_key_hash -> "key_hash" - | Solver.Core.C_key -> "key" - | Solver.Core.C_signature -> "signature" - | Solver.Core.C_operation -> "operation" - | Solver.Core.C_contract -> "contract" - | Solver.Core.C_chain_id -> "chain_id" + | C_arrow -> "arrow" + | C_option -> "option" + | C_record -> failwith "record" + | C_variant -> failwith "variant" + | C_map -> "map" + | C_big_map -> "big_map" + | C_list -> "list" + | C_set -> "set" + | C_unit -> "unit" + | C_string -> "string" + | C_nat -> "nat" + | C_mutez -> "mutez" + | C_timestamp -> "timestamp" + | C_int -> "int" + | C_address -> "address" + | C_bytes -> "bytes" + | C_key_hash -> "key_hash" + | C_key -> "key" + | C_signature -> "signature" + | C_operation -> "operation" + | C_contract -> "contract" + | C_chain_id -> "chain_id" in fprintf ppf "CTOR %a %s()" Var.pp tv ct - |SC_Alias (a, b) -> fprintf ppf "Alias %a %a" Var.pp a Var.pp b + |SC_Alias { a; b } -> fprintf ppf "Alias %a %a" Var.pp a Var.pp b |SC_Poly _ -> fprintf ppf "Poly" |SC_Typeclass _ -> fprintf ppf "TC" @@ -47,6 +48,6 @@ let already_selected : _ -> already_selected -> unit = fun ppf already_selected let _ = already_selected in fprintf ppf "ALREADY_SELECTED" -let state : _ -> state -> unit = fun ppf state -> +let state : _ -> typer_state -> unit = fun ppf state -> let { structured_dbs=a ; already_selected=b } = state in fprintf ppf "STATE %a %a" structured_dbs a already_selected b diff --git a/src/passes/8-typer-new/solver.ml b/src/passes/8-typer-new/solver.ml index e72f49e22..eeb84c39d 100644 --- a/src/passes/8-typer-new/solver.ml +++ b/src/passes/8-typer-new/solver.ml @@ -1,10 +1,15 @@ open Trace module Core = Typesystem.Core +module Map = RedBlackTrees.PolyMap +module Set = RedBlackTrees.PolySet +module UF = UnionFind.Poly2 module Wrap = Wrap open Wrap +open Ast_typed.Misc +(* TODO: remove this, it's not used anymore *) module TypeVariable = struct type t = Core.type_variable @@ -13,14 +18,6 @@ struct end -module UF = UnionFind.Partition0.Make(TypeVariable) - -type unionfind = UF.t - -(* end unionfind *) - -(* representant for an equivalence class of type variables *) -module TypeVariableMap = Map.Make(TypeVariable) (* @@ -59,48 +56,7 @@ Workflow: *) -open Core - -type structured_dbs = { - all_constraints : type_constraint_simpl list ; - aliases : unionfind ; - (* assignments (passive data structure). - Now: just a map from unification vars to types (pb: what about partial types?) - maybe just local assignments (allow only vars as children of pair(α,β)) *) - (* TODO: the rhs of the map should not repeat the variable name. *) - assignments : c_constructor_simpl TypeVariableMap.t ; - grouped_by_variable : constraints TypeVariableMap.t ; (* map from (unionfind) variables to constraints containing them *) - cycle_detection_toposort : unit ; (* example of structured db that we'll add later *) -} - -and constraints = { - (* If implemented in a language with decent sets, these should be sets not lists. *) - constructor : c_constructor_simpl list ; (* List of ('a = constructor(args…)) constraints *) - poly : c_poly_simpl list ; (* List of ('a = forall 'b, some_type) constraints *) - tc : c_typeclass_simpl list ; (* List of (typeclass(args…)) constraints *) -} - -and c_constructor_simpl = { - tv : type_variable; - c_tag : constant_tag; - tv_list : type_variable list; -} -(* copy-pasted from core.ml *) -and c_const = (type_variable * type_expression) -and c_equation = (type_expression * type_expression) -and c_typeclass_simpl = { - tc : typeclass ; - args : type_variable list ; -} -and c_poly_simpl = { - tv : type_variable ; - forall : p_forall ; -} -and type_constraint_simpl = - SC_Constructor of c_constructor_simpl (* α = ctor(β, …) *) - | SC_Alias of (type_variable * type_variable) (* α = β *) - | SC_Poly of c_poly_simpl (* α = forall β, δ where δ can be a more complex type *) - | SC_Typeclass of c_typeclass_simpl (* TC(α, …) *) +open Ast_typed.Types module UnionFindWrapper = struct (* Light wrapper for API for grouped_by_variable in the structured @@ -109,7 +65,7 @@ module UnionFindWrapper = struct fun variable dbs -> let variable , aliases = UF.get_or_set variable dbs.aliases in let dbs = { dbs with aliases } in - match TypeVariableMap.find_opt variable dbs.grouped_by_variable with + match Map.find_opt variable dbs.grouped_by_variable with Some l -> l | None -> { constructor = [] ; @@ -122,9 +78,9 @@ module UnionFindWrapper = struct let dbs = { dbs with aliases } in *) let variable_repr , aliases = UF.get_or_set variable dbs.aliases in let dbs = { dbs with aliases } in - let grouped_by_variable = TypeVariableMap.update variable_repr (function + let grouped_by_variable = Map.update variable_repr (function None -> Some c - | Some x -> Some { + | Some (x : constraints) -> Some { constructor = c.constructor @ x.constructor ; poly = c.poly @ x.poly ; tc = c.tc @ x.tc ; @@ -150,7 +106,7 @@ module UnionFindWrapper = struct (* Replace the two entries in grouped_by_variable by a single one *) ( let get_constraints ab = - match TypeVariableMap.find_opt ab dbs.grouped_by_variable with + match Map.find_opt ab dbs.grouped_by_variable with | Some x -> x | None -> { constructor = [] ; poly = [] ; tc = [] } in let constraints_a = get_constraints variable_repr_a in @@ -161,10 +117,10 @@ module UnionFindWrapper = struct tc = constraints_a.tc @ constraints_b.tc ; } in let grouped_by_variable = - TypeVariableMap.add variable_repr_a all_constraints dbs.grouped_by_variable in + Map.add variable_repr_a all_constraints dbs.grouped_by_variable in let dbs = { dbs with grouped_by_variable} in let grouped_by_variable = - TypeVariableMap.remove variable_repr_b dbs.grouped_by_variable in + Map.remove variable_repr_b dbs.grouped_by_variable in let dbs = { dbs with grouped_by_variable} in dbs ) @@ -207,7 +163,7 @@ let normalizer_grouped_by_variable : (type_constraint_simpl , type_constraint_si SC_Constructor ({tv ; c_tag = _ ; tv_list} as c) -> store_constraint (tv :: tv_list) {constructor = [c] ; poly = [] ; tc = []} | SC_Typeclass ({tc = _ ; args} as c) -> store_constraint args {constructor = [] ; poly = [] ; tc = [c]} | SC_Poly ({tv; forall = _} as c) -> store_constraint [tv] {constructor = [] ; poly = [c] ; tc = []} - | SC_Alias (a , b) -> UnionFindWrapper.merge_constraints a b dbs + | SC_Alias { a; b } -> UnionFindWrapper.merge_constraints a b dbs in (dbs , [new_constraint]) (** Stores the first assinment ('a = ctor('b, …)) that is encountered. @@ -219,7 +175,7 @@ let normalizer_assignments : (type_constraint_simpl , type_constraint_simpl) nor fun dbs new_constraint -> match new_constraint with | SC_Constructor ({tv ; c_tag = _ ; tv_list = _} as c) -> - let assignments = TypeVariableMap.update tv (function None -> Some c | e -> e) dbs.assignments in + let assignments = Map.update tv (function None -> Some c | e -> e) dbs.assignments in let dbs = {dbs with assignments} in (dbs , [new_constraint]) | _ -> @@ -254,47 +210,47 @@ let rec normalizer_simpl : (type_constraint , type_constraint_simpl) normalizer fun dbs new_constraint -> let insert_fresh a b = let fresh = Core.fresh_type_variable () in - let (dbs , cs1) = normalizer_simpl dbs (C_equation (P_variable fresh, a)) in - let (dbs , cs2) = normalizer_simpl dbs (C_equation (P_variable fresh, b)) in + let (dbs , cs1) = normalizer_simpl dbs (c_equation (P_variable fresh) a) in + let (dbs , cs2) = normalizer_simpl dbs (c_equation (P_variable fresh) b) in (dbs , cs1 @ cs2) in let split_constant a c_tag args = let fresh_vars = List.map (fun _ -> Core.fresh_type_variable ()) args in - let fresh_eqns = List.map (fun (v,t) -> C_equation (P_variable v, t)) (List.combine fresh_vars args) in + let fresh_eqns = List.map (fun (v,t) -> c_equation (P_variable v) t) (List.combine fresh_vars args) in let (dbs , recur) = List.fold_map_acc normalizer_simpl dbs fresh_eqns in (dbs , [SC_Constructor {tv=a;c_tag;tv_list=fresh_vars}] @ List.flatten recur) in let gather_forall a forall = (dbs , [SC_Poly { tv=a; forall }]) in - let gather_alias a b = (dbs , [SC_Alias (a, b)]) in + let gather_alias a b = (dbs , [SC_Alias { a ; b }]) in let reduce_type_app a b = let (reduced, new_constraints) = check_applied @@ type_level_eval b in let (dbs , recur) = List.fold_map_acc normalizer_simpl dbs new_constraints in - let (dbs , resimpl) = normalizer_simpl dbs (C_equation (a , reduced)) in (* Note: this calls recursively but cant't fall in the same case. *) + let (dbs , resimpl) = normalizer_simpl dbs (c_equation a reduced) in (* Note: this calls recursively but cant't fall in the same case. *) (dbs , resimpl @ List.flatten recur) in let split_typeclass args tc = let fresh_vars = List.map (fun _ -> Core.fresh_type_variable ()) args in - let fresh_eqns = List.map (fun (v,t) -> C_equation (P_variable v, t)) (List.combine fresh_vars args) in + let fresh_eqns = List.map (fun (v,t) -> c_equation (P_variable v) t) (List.combine fresh_vars args) in let (dbs , recur) = List.fold_map_acc normalizer_simpl dbs fresh_eqns in (dbs, [SC_Typeclass { tc ; args = fresh_vars }] @ List.flatten recur) in match new_constraint with (* break down (forall 'b, body = forall 'c, body') into ('a = forall 'b, body and 'a = forall 'c, body')) *) - | C_equation ((P_forall _ as a), (P_forall _ as b)) -> insert_fresh a b + | C_equation {aval=(P_forall _ as a); bval=(P_forall _ as b)} -> insert_fresh a b (* break down (forall 'b, body = c(args)) into ('a = forall 'b, body and 'a = c(args)) *) - | C_equation ((P_forall _ as a), (P_constant _ as b)) -> insert_fresh a b + | C_equation {aval=(P_forall _ as a); bval=(P_constant _ as b)} -> insert_fresh a b (* break down (c(args) = c'(args')) into ('a = c(args) and 'a = c'(args')) *) - | C_equation ((P_constant _ as a), (P_constant _ as b)) -> insert_fresh a b + | C_equation {aval=(P_constant _ as a); bval=(P_constant _ as b)} -> insert_fresh a b (* break down (c(args) = forall 'b, body) into ('a = c(args) and 'a = forall 'b, body) *) - | C_equation ((P_constant _ as a), (P_forall _ as b)) -> insert_fresh a b - | C_equation ((P_forall forall), (P_variable b)) -> gather_forall b forall - | C_equation (P_variable a, P_forall forall) -> gather_forall a forall - | C_equation (P_variable a, P_variable b) -> gather_alias a b - | C_equation (P_variable a, P_constant (c_tag , args)) -> split_constant a c_tag args - | C_equation (P_constant (c_tag , args), P_variable b) -> split_constant b c_tag args + | C_equation {aval=(P_constant _ as a); bval=(P_forall _ as b)} -> insert_fresh a b + | C_equation {aval=(P_forall forall); bval=(P_variable b)} -> gather_forall b forall + | C_equation {aval=P_variable a; bval=P_forall forall} -> gather_forall a forall + | C_equation {aval=P_variable a; bval=P_variable b} -> gather_alias a b + | C_equation {aval=P_variable a; bval=P_constant { p_ctor_tag; p_ctor_args }} -> split_constant a p_ctor_tag p_ctor_args + | C_equation {aval=P_constant {p_ctor_tag; p_ctor_args}; bval=P_variable b} -> split_constant b p_ctor_tag p_ctor_args (* Reduce the type-level application, and simplify the resulting constraint + the extra constraints (typeclasses) that appeared at the forall binding site *) - | C_equation ((_ as a), (P_apply _ as b)) -> reduce_type_app a b - | C_equation ((P_apply _ as a), (_ as b)) -> reduce_type_app b a + | C_equation {aval=(_ as a); bval=(P_apply _ as b)} -> reduce_type_app a b + | C_equation {aval=(P_apply _ as a); bval=(_ as b)} -> reduce_type_app b a (* break down (TC(args)) into (TC('a, …) and ('a = arg) …) *) - | C_typeclass (args, tc) -> split_typeclass args tc - | C_access_label (tv, label, result) -> let _todo = ignore (tv, label, result) in failwith "TODO" + | C_typeclass { tc_args; typeclass } -> split_typeclass tc_args typeclass + | C_access_label { c_access_label_tval; accessor; c_access_label_tvar } -> let _todo = ignore (c_access_label_tval, accessor, c_access_label_tvar) in failwith "TODO" (* tv, label, result *) (* Random notes from live discussion. Kept here to include bits as a rationale later on / remind me of the discussion in the short term. * Feel free to erase if it rots here for too long. @@ -366,7 +322,6 @@ type 'selector_output propagator = 'selector_output -> structured_dbs -> new_con (* selector / propagation rule for breaking down composite types * For now: break pair(a, b) = pair(c, d) into a = c, b = d *) -type output_break_ctor = { a_k_var : c_constructor_simpl ; a_k'_var' : c_constructor_simpl } let selector_break_ctor : (type_constraint_simpl, output_break_ctor) selector = (* find two rules with the shape a = k(var …) and a = k'(var' …) *) fun type_constraint_simpl dbs -> @@ -479,28 +434,28 @@ let compare_simple_c_constant = function has been copied here. *) let debug_pp_constant : _ -> constant_tag -> unit = fun ppf c_tag -> let ct = match c_tag with - | Core.C_arrow -> "arrow" - | Core.C_option -> "option" - | Core.C_record -> failwith "record" - | Core.C_variant -> failwith "variant" - | Core.C_map -> "map" - | Core.C_big_map -> "big_map" - | Core.C_list -> "list" - | Core.C_set -> "set" - | Core.C_unit -> "unit" - | Core.C_string -> "string" - | Core.C_nat -> "nat" - | Core.C_mutez -> "mutez" - | Core.C_timestamp -> "timestamp" - | Core.C_int -> "int" - | Core.C_address -> "address" - | Core.C_bytes -> "bytes" - | Core.C_key_hash -> "key_hash" - | Core.C_key -> "key" - | Core.C_signature -> "signature" - | Core.C_operation -> "operation" - | Core.C_contract -> "contract" - | Core.C_chain_id -> "chain_id" + | T.C_arrow -> "arrow" + | T.C_option -> "option" + | T.C_record -> failwith "record" + | T.C_variant -> failwith "variant" + | T.C_map -> "map" + | T.C_big_map -> "big_map" + | T.C_list -> "list" + | T.C_set -> "set" + | T.C_unit -> "unit" + | T.C_string -> "string" + | T.C_nat -> "nat" + | T.C_mutez -> "mutez" + | T.C_timestamp -> "timestamp" + | T.C_int -> "int" + | T.C_address -> "address" + | T.C_bytes -> "bytes" + | T.C_key_hash -> "key_hash" + | T.C_key -> "key" + | T.C_signature -> "signature" + | T.C_operation -> "operation" + | T.C_contract -> "contract" + | T.C_chain_id -> "chain_id" in Format.fprintf ppf "%s" ct @@ -515,7 +470,7 @@ let propagator_break_ctor : output_break_ctor propagator = (* produce constraints: *) (* a.tv = b.tv *) - let eq1 = C_equation (P_variable a.tv, P_variable b.tv) in + let eq1 = c_equation (P_variable a.tv) (P_variable b.tv) in (* a.c_tag = b.c_tag *) if (compare_simple_c_constant a.c_tag b.c_tag) <> 0 then failwith (Format.asprintf "type error: incompatible types, not same ctor %a vs. %a (compare returns %d)" debug_pp_c_constructor_simpl a debug_pp_c_constructor_simpl b (compare_simple_c_constant a.c_tag b.c_tag)) @@ -524,14 +479,13 @@ let propagator_break_ctor : output_break_ctor propagator = if List.length a.tv_list <> List.length b.tv_list then failwith "type error: incompatible types, not same length" else - let eqs3 = List.map2 (fun aa bb -> C_equation (P_variable aa, P_variable bb)) a.tv_list b.tv_list in + let eqs3 = List.map2 (fun aa bb -> c_equation (P_variable aa) (P_variable bb)) a.tv_list b.tv_list in let eqs = eq1 :: eqs3 in (eqs , []) (* no new assignments *) (* TODO : with our selectors, the selection depends on the order in which the constraints are added :-( :-( :-( :-( We need to return a lazy stream of constraints. *) -type output_specialize1 = { poly : c_poly_simpl ; a_k_var : c_constructor_simpl } let ( (function [] -> 0 | _::_ -> -1) (* This follows the behaviour of Pervasives.compare for lists of different length *) let compare_type_variable a b = Var.compare a b -let compare_label (a:accessor) (b:accessor) = +let compare_label (a:label) (b:label) = let Label a = a in let Label b = b in String.compare a b @@ -564,29 +518,29 @@ and compare_type_expression = function | P_variable b -> compare_type_variable a b | P_constant _ -> -1 | P_apply _ -> -1) - | P_constant (a1, a2) -> (function + | P_constant { p_ctor_tag=a1; p_ctor_args=a2 } -> (function | P_forall _ -> 1 | P_variable _ -> 1 - | P_constant (b1, b2) -> compare_simple_c_constant a1 b1 compare_list compare_type_expression a2 b2 + | P_constant { p_ctor_tag=b1; p_ctor_args=b2 } -> compare_simple_c_constant a1 b1 compare_list compare_type_expression a2 b2 | P_apply _ -> -1) - | P_apply (a1, a2) -> (function + | P_apply { tf=a1; targ=a2 } -> (function | P_forall _ -> 1 | P_variable _ -> 1 | P_constant _ -> 1 - | P_apply (b1, b2) -> compare_type_expression a1 b1 compare_type_expression a2 b2) + | P_apply { tf=b1; targ=b2 } -> compare_type_expression a1 b1 compare_type_expression a2 b2) and compare_type_constraint = function - | C_equation (a1, a2) -> (function - | C_equation (b1, b2) -> compare_type_expression a1 b1 compare_type_expression a2 b2 + | C_equation { aval=a1; bval=a2 } -> (function + | C_equation { aval=b1; bval=b2 } -> compare_type_expression a1 b1 compare_type_expression a2 b2 | C_typeclass _ -> -1 | C_access_label _ -> -1) - | C_typeclass (a1, a2) -> (function + | C_typeclass { tc_args=a1; typeclass=a2 } -> (function | C_equation _ -> 1 - | C_typeclass (b1, b2) -> compare_list compare_type_expression a1 b1 compare_typeclass a2 b2 + | C_typeclass { tc_args=b1; typeclass=b2 } -> compare_list compare_type_expression a1 b1 compare_typeclass a2 b2 | C_access_label _ -> -1) - | C_access_label (a1, a2, a3) -> (function + | C_access_label { c_access_label_tval=a1; accessor=a2; c_access_label_tvar=a3 } -> (function | C_equation _ -> 1 | C_typeclass _ -> 1 - | C_access_label (b1, b2, b3) -> compare_type_expression a1 b1 compare_label a2 b2 compare_type_variable a3 b3) + | C_access_label { c_access_label_tval=b1; accessor=b2; c_access_label_tvar=b3 } -> compare_type_expression a1 b1 compare_label a2 b2 compare_type_variable a3 b3) let compare_type_constraint_list = compare_list compare_type_constraint let compare_p_forall { binder = a1; constraints = a2; body = a3 } @@ -607,17 +561,6 @@ let compare_output_specialize1 { poly = a1; a_k_var = a2 } { poly = b1; a_k_var let compare_output_break_ctor { a_k_var=a1; a_k'_var'=a2 } { a_k_var=b1; a_k'_var'=b2 } = compare_c_constructor_simpl a1 b1 compare_c_constructor_simpl a2 b2 -module OutputSpecialize1 : (Set.OrderedType with type t = output_specialize1) = struct - type t = output_specialize1 - let compare = compare_output_specialize1 -end - - -module BreakCtor : (Set.OrderedType with type t = output_break_ctor) = struct - type t = output_break_ctor - let compare = compare_output_break_ctor -end - let selector_specialize1 : (type_constraint_simpl, output_specialize1) selector = (* find two rules with the shape (a = forall b, d) and a = k'(var' …) or vice versa *) (* TODO: do the same for two rules with the shape (a = forall b, d) and tc(a…) *) @@ -651,23 +594,21 @@ let propagator_specialize1 : output_specialize1 propagator = let fresh_existential = Core.fresh_type_variable () in (* Produce the constraint (b.tv = a.body[a.binder |-> fresh_existential]) The substitution is obtained by immediately applying the forall. *) - let apply = (P_apply (P_forall a.forall , P_variable fresh_existential)) in + let apply = (P_apply {tf = (P_forall a.forall); targ = P_variable fresh_existential}) in let (reduced, new_constraints) = check_applied @@ type_level_eval apply in - let eq1 = C_equation (P_variable b.tv, reduced) in + let eq1 = c_equation (P_variable b.tv) reduced in let eqs = eq1 :: new_constraints in (eqs, []) (* no new assignments *) -module M (BlaBla : Set.OrderedType) = struct - module AlreadySelected = Set.Make(BlaBla) - - let select_and_propagate : ('old_input, 'selector_output) selector -> BlaBla.t propagator -> _ -> 'a -> structured_dbs -> _ * new_constraints * new_assignments = + let select_and_propagate : ('old_input, 'selector_output) selector -> _ propagator -> _ -> 'a -> structured_dbs -> _ * new_constraints * new_assignments = + let mem elt set = match RedBlackTrees.PolySet.find_opt elt set with None -> false | Some _ -> true in fun selector propagator -> fun already_selected old_type_constraint dbs -> (* TODO: thread some state to know which selector outputs were already seen *) match selector old_type_constraint dbs with WasSelected selected_outputs -> (* TODO: fold instead. *) - let (already_selected , selected_outputs) = List.fold_left (fun (already_selected, selected_outputs) elt -> if AlreadySelected.mem elt already_selected then (AlreadySelected.add elt already_selected , elt :: selected_outputs) + let (already_selected , selected_outputs) = List.fold_left (fun (already_selected, selected_outputs) elt -> if mem elt already_selected then (RedBlackTrees.PolySet.add elt already_selected , elt :: selected_outputs) else (already_selected , selected_outputs)) (already_selected , selected_outputs) selected_outputs in (* Call the propagation rule *) let new_contraints_and_assignments = List.map (fun s -> propagator s dbs) selected_outputs in @@ -676,25 +617,16 @@ module M (BlaBla : Set.OrderedType) = struct (already_selected , List.flatten new_constraints , List.flatten new_assignments) | WasNotSelected -> (already_selected, [] , []) -end -module M_break_ctor = M(BreakCtor) -module M_specialize1 = M(OutputSpecialize1) - -let select_and_propagate_break_ctor = M_break_ctor.select_and_propagate selector_break_ctor propagator_break_ctor -let select_and_propagate_specialize1 = M_specialize1.select_and_propagate selector_specialize1 propagator_specialize1 - -type already_selected = { - break_ctor : M_break_ctor.AlreadySelected.t ; - specialize1 : M_specialize1.AlreadySelected.t ; -} +let select_and_propagate_break_ctor = select_and_propagate selector_break_ctor propagator_break_ctor +let select_and_propagate_specialize1 = select_and_propagate selector_specialize1 propagator_specialize1 (* Takes a constraint, applies all selector+propagator pairs to it. Keeps track of which constraints have already been selected. *) let select_and_propagate_all' : _ -> type_constraint_simpl selector_input -> structured_dbs -> _ * new_constraints * structured_dbs = let aux sel_propag new_constraint (already_selected , new_constraints , dbs) = let (already_selected , new_constraints', new_assignments) = sel_propag already_selected new_constraint dbs in - let assignments = List.fold_left (fun acc ({tv;c_tag=_;tv_list=_} as ele) -> TypeVariableMap.update tv (function None -> Some ele | x -> x) acc) dbs.assignments new_assignments in + let assignments = List.fold_left (fun acc ({tv;c_tag=_;tv_list=_} as ele) -> Map.update tv (function None -> Some ele | x -> x) acc) dbs.assignments new_assignments in let dbs = { dbs with assignments } in (already_selected , new_constraints' @ new_constraints , dbs) in @@ -752,12 +684,7 @@ let rec select_and_propagate_all : _ -> type_constraint selector_input list -> s * constraints : constraints TypeVariableMap.t ; * } *) -type state = { - structured_dbs : structured_dbs ; - already_selected : already_selected ; -} - -let initial_state : state = (* { +let initial_state : typer_state = (* { * unification_vars = UF.empty ; * constraints = TypeVariableMap.empty ; * assignments = TypeVariableMap.empty ; @@ -766,14 +693,14 @@ let initial_state : state = (* { structured_dbs = { all_constraints = [] ; (* type_constraint_simpl list *) - aliases = UF.empty ; (* unionfind *) - assignments = TypeVariableMap.empty; (* c_constructor_simpl TypeVariableMap.t *) - grouped_by_variable = TypeVariableMap.empty; (* constraints TypeVariableMap.t *) + aliases = UF.empty (fun s -> Format.asprintf "%a" Var.pp s) Var.compare ; (* unionfind *) + assignments = Map.create ~cmp:Var.compare; (* c_constructor_simpl TypeVariableMap.t *) + grouped_by_variable = Map.create ~cmp:Var.compare; (* constraints TypeVariableMap.t *) cycle_detection_toposort = (); (* unit *) } ; already_selected = { - break_ctor = M_break_ctor.AlreadySelected.empty ; - specialize1 = M_specialize1.AlreadySelected.empty ; + break_ctor = Set.create ~cmp:compare_output_break_ctor; + specialize1 = Set.create ~cmp:compare_output_specialize1 ; } } @@ -784,7 +711,7 @@ let initial_state : state = (* { Also, we should check at these places that we indeed do not need the state any further. Suzanne *) -let discard_state (_ : state) = () +let discard_state (_ : typer_state) = () (* let replace_var_in_state = fun (v : type_variable) (state : state) -> *) (* let aux_tv : type_expression -> _ = function *) @@ -804,7 +731,7 @@ let discard_state (_ : state) = () (* in List.map aux state *) (* This is the solver *) -let aggregate_constraints : state -> type_constraint list -> state result = fun state newc -> +let aggregate_constraints : typer_state -> type_constraint list -> typer_state result = fun state newc -> (* TODO: Iterate over constraints *) let _todo = ignore (state, newc) in let (a, b) = select_and_propagate_all state.already_selected newc state.structured_dbs in diff --git a/src/passes/8-typer-new/typer.ml b/src/passes/8-typer-new/typer.ml index 3dc3c3d39..4f3c1f77c 100644 --- a/src/passes/8-typer-new/typer.ml +++ b/src/passes/8-typer-new/typer.ml @@ -8,13 +8,14 @@ module Solver = Solver type environment = Environment.t module Errors = Errors open Errors +module Map = RedBlackTrees.PolyMap open Todo_use_fold_generator (* Extract pairs of (name,type) in the declaration and add it to the environment *) -let rec type_declaration env state : I.declaration -> (environment * Solver.state * O.declaration option) result = function +let rec type_declaration env state : I.declaration -> (environment * O.typer_state * O.declaration option) result = function | Declaration_type (type_name , type_expression) -> let%bind tv = evaluate_type env type_expression in let env' = Environment.add_type (type_name) tv env in @@ -31,7 +32,7 @@ let rec type_declaration env state : I.declaration -> (environment * Solver.stat ok (post_env, state' , Some (O.Declaration_constant { binder ; expr ; inline ; post_env} )) ) -and type_match : environment -> Solver.state -> O.type_expression -> I.matching_expr -> I.expression -> Location.t -> (O.matching_expr * Solver.state) result = +and type_match : environment -> O.typer_state -> O.type_expression -> I.matching_expr -> I.expression -> Location.t -> (O.matching_expr * O.typer_state) result = fun e state t i ae loc -> match i with | Match_bool {match_true ; match_false} -> let%bind _ = @@ -194,11 +195,11 @@ and evaluate_type (e:environment) (t:I.type_expression) : O.type_expression resu in return (T_operator (opt)) -and type_expression : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result = fun e state ?tv_opt ae -> +and type_expression : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result = fun e state ?tv_opt ae -> let () = ignore tv_opt in (* For compatibility with the old typer's API, this argument can be removed once the new typer is used. *) let open Solver in let module L = Logger.Stateful() in - let return : _ -> Solver.state -> _ -> _ (* return of type_expression *) = fun expr state constraints type_name -> + let return : _ -> O.typer_state -> _ -> _ (* return of type_expression *) = fun expr state constraints type_name -> let%bind new_state = aggregate_constraints state constraints in let tv = t_variable type_name () in let location = ae.location in @@ -438,8 +439,8 @@ and type_constant (name:I.constant') (lst:O.type_expression list) (tv_opt:O.type ok(name, tv) (* Apply type_declaration on every node of the AST_core from the root p *) -let type_program_returns_state ((env, state, p) : environment * Solver.state * I.program) : (environment * Solver.state * O.program) result = - let aux ((e : environment), (s : Solver.state) , (ds : O.declaration Location.wrap list)) (d:I.declaration Location.wrap) = +let type_program_returns_state ((env, state, p) : environment * O.typer_state * I.program) : (environment * O.typer_state * O.program) result = + let aux ((e : environment), (s : O.typer_state) , (ds : O.declaration Location.wrap list)) (d:I.declaration Location.wrap) = let%bind (e' , s' , d'_opt) = type_declaration e s (Location.unwrap d) in let ds' = match d'_opt with | None -> ds @@ -453,8 +454,8 @@ let type_program_returns_state ((env, state, p) : environment * Solver.state * I let declarations = List.rev declarations in (* Common hack to have O(1) append: prepend and then reverse *) ok (env', state', declarations) -let type_and_subst_xyz (env_state_node : environment * Solver.state * 'a) (apply_substs : 'b Typesystem.Misc.Substitution.Pattern.w) (type_xyz_returns_state : (environment * Solver.state * 'a) -> (environment * Solver.state * 'b) Trace.result) : ('b * Solver.state) result = - let%bind (env, state, program) = type_xyz_returns_state env_state_node in +let type_and_subst_xyz (env_state_node : environment * O.typer_state * 'a) (apply_substs : 'b Typesystem.Misc.Substitution.Pattern.w) (type_xyz_returns_state : (environment * O.typer_state * 'a) -> (environment * O.typer_state * 'b) Trace.result) : ('b * O.typer_state) result = + let%bind (env, state, node) = type_xyz_returns_state env_state_node in let subst_all = let aliases = state.structured_dbs.aliases in let assignments = state.structured_dbs.assignments in @@ -466,29 +467,29 @@ let type_and_subst_xyz (env_state_node : environment * Solver.state * 'a) (apply try Some (Solver.UF.repr variable aliases) with Not_found -> None in let%bind assignment = trace_option (simple_error (Format.asprintf "can't find assignment for root %a" Var.pp root)) @@ - (Solver.TypeVariableMap.find_opt root assignments) in - let Solver.{ tv ; c_tag ; tv_list } = assignment in + (Map.find_opt root assignments) in + let O.{ tv ; c_tag ; tv_list } = assignment in let () = ignore tv (* I think there is an issue where the tv is stored twice (as a key and in the element itself) *) in let%bind (expr : O.type_content) = Typesystem.Core.type_expression'_of_simple_c_constant (c_tag , (List.map (fun s -> O.t_variable s ()) tv_list)) in ok @@ expr in - let p = apply_substs ~substs program in + let p = apply_substs ~substs node in p in - let%bind program = subst_all in + let%bind node = subst_all in let () = ignore env in (* TODO: shouldn't we use the `env` somewhere? *) - ok (program, state) + ok (node, state) -let type_program (p : I.program) : (O.program * Solver.state) result = +let type_program (p : I.program) : (O.program * O.typer_state) result = let empty_env = DEnv.default in let empty_state = Solver.initial_state in type_and_subst_xyz (empty_env , empty_state , p) Typesystem.Misc.Substitution.Pattern.s_program type_program_returns_state -let type_expression_returns_state : (environment * Solver.state * I.expression) -> (environment * Solver.state * O.expression) Trace.result = +let type_expression_returns_state : (environment * O.typer_state * I.expression) -> (environment * O.typer_state * O.expression) Trace.result = fun (env, state, e) -> let%bind (e , state) = type_expression env state e in ok (env, state, e) -let type_expression_subst (env : environment) (state : Solver.state) ?(tv_opt : O.type_expression option) (e : I.expression) : (O.expression * Solver.state) result = +let type_expression_subst (env : environment) (state : O.typer_state) ?(tv_opt : O.type_expression option) (e : I.expression) : (O.expression * O.typer_state) result = let () = ignore tv_opt in (* For compatibility with the old typer's API, this argument can be removed once the new typer is used. *) type_and_subst_xyz (env , state , e) Typesystem.Misc.Substitution.Pattern.s_expression type_expression_returns_state @@ -496,14 +497,14 @@ let untype_type_expression = Untyper.untype_type_expression let untype_expression = Untyper.untype_expression (* These aliases are just here for quick navigation during debug, and can safely be removed later *) -let [@warning "-32"] (*rec*) type_declaration _env _state : I.declaration -> (environment * Solver.state * O.declaration option) result = type_declaration _env _state -and [@warning "-32"] type_match : environment -> Solver.state -> O.type_expression -> I.matching_expr -> I.expression -> Location.t -> (O.matching_expr * Solver.state) result = type_match +let [@warning "-32"] (*rec*) type_declaration _env _state : I.declaration -> (environment * O.typer_state * O.declaration option) result = type_declaration _env _state +and [@warning "-32"] type_match : environment -> O.typer_state -> O.type_expression -> I.matching_expr -> I.expression -> Location.t -> (O.matching_expr * O.typer_state) result = type_match and [@warning "-32"] evaluate_type (e:environment) (t:I.type_expression) : O.type_expression result = evaluate_type e t -and [@warning "-32"] type_expression : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result = type_expression +and [@warning "-32"] type_expression : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result = type_expression and [@warning "-32"] type_lambda e state lam = type_lambda e state lam and [@warning "-32"] type_constant (name:I.constant') (lst:O.type_expression list) (tv_opt:O.type_expression option) : (O.constant' * O.type_expression) result = type_constant name lst tv_opt -let [@warning "-32"] type_program_returns_state ((env, state, p) : environment * Solver.state * I.program) : (environment * Solver.state * O.program) result = type_program_returns_state (env, state, p) -let [@warning "-32"] type_and_subst_xyz (env_state_node : environment * Solver.state * 'a) (apply_substs : 'b Typesystem.Misc.Substitution.Pattern.w) (type_xyz_returns_state : (environment * Solver.state * 'a) -> (environment * Solver.state * 'b) Trace.result) : ('b * Solver.state) result = type_and_subst_xyz env_state_node apply_substs type_xyz_returns_state -let [@warning "-32"] type_program (p : I.program) : (O.program * Solver.state) result = type_program p -let [@warning "-32"] type_expression_returns_state : (environment * Solver.state * I.expression) -> (environment * Solver.state * O.expression) Trace.result = type_expression_returns_state -let [@warning "-32"] type_expression_subst (env : environment) (state : Solver.state) ?(tv_opt : O.type_expression option) (e : I.expression) : (O.expression * Solver.state) result = type_expression_subst env state ?tv_opt e +let [@warning "-32"] type_program_returns_state ((env, state, p) : environment * O.typer_state * I.program) : (environment * O.typer_state * O.program) result = type_program_returns_state (env, state, p) +let [@warning "-32"] type_and_subst_xyz (env_state_node : environment * O.typer_state * 'a) (apply_substs : 'b Typesystem.Misc.Substitution.Pattern.w) (type_xyz_returns_state : (environment * O.typer_state * 'a) -> (environment * O.typer_state * 'b) Trace.result) : ('b * O.typer_state) result = type_and_subst_xyz env_state_node apply_substs type_xyz_returns_state +let [@warning "-32"] type_program (p : I.program) : (O.program * O.typer_state) result = type_program p +let [@warning "-32"] type_expression_returns_state : (environment * O.typer_state * I.expression) -> (environment * O.typer_state * O.expression) Trace.result = type_expression_returns_state +let [@warning "-32"] type_expression_subst (env : environment) (state : O.typer_state) ?(tv_opt : O.type_expression option) (e : I.expression) : (O.expression * O.typer_state) result = type_expression_subst env state ?tv_opt e diff --git a/src/passes/8-typer-new/typer.mli b/src/passes/8-typer-new/typer.mli index e5b91de0a..9c2e267fc 100644 --- a/src/passes/8-typer-new/typer.mli +++ b/src/passes/8-typer-new/typer.mli @@ -38,11 +38,11 @@ module Errors : sig *) end -val type_program : I.program -> (O.program * Solver.state) result -val type_declaration : environment -> Solver.state -> I.declaration -> (environment * Solver.state * O.declaration option) result +val type_program : I.program -> (O.program * O.typer_state) result +val type_declaration : environment -> O.typer_state -> I.declaration -> (environment * O.typer_state * O.declaration option) result val evaluate_type : environment -> I.type_expression -> O.type_expression result -val type_expression : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result -val type_expression_subst : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result +val type_expression : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result +val type_expression_subst : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result val type_constant : I.constant' -> O.type_expression list -> O.type_expression option -> (O.constant' * O.type_expression) result val untype_type_expression : O.type_expression -> I.type_expression result diff --git a/src/passes/8-typer-new/wrap.ml b/src/passes/8-typer-new/wrap.ml index 525f2dfd8..51dcc794f 100644 --- a/src/passes/8-typer-new/wrap.ml +++ b/src/passes/8-typer-new/wrap.ml @@ -1,4 +1,5 @@ open Trace +open Ast_typed.Misc module Core = Typesystem.Core module I = Ast_core @@ -35,17 +36,17 @@ let rec type_expression_to_type_value : T.type_expression -> O.type_value = fun | T_sum kvmap -> let () = failwith "fixme: don't use to_list, it drops the variant keys, rows have a differnt kind than argument lists for now!" in let tlist = List.map (fun ({ctor_type;_}:T.ctor_content) -> ctor_type) (T.CMap.to_list kvmap) in - P_constant (C_variant, List.map type_expression_to_type_value tlist) + p_constant C_variant (List.map type_expression_to_type_value tlist) | T_record kvmap -> let () = failwith "fixme: don't use to_list, it drops the record keys, rows have a differnt kind than argument lists for now!" in let tlist = List.map (fun ({field_type;_}:T.field_content) -> field_type) (T.LMap.to_list kvmap) in - P_constant (C_record, List.map type_expression_to_type_value tlist) + p_constant C_record (List.map type_expression_to_type_value tlist) | T_arrow {type1;type2} -> - P_constant (C_arrow, List.map type_expression_to_type_value [ type1 ; type2 ]) + p_constant C_arrow (List.map type_expression_to_type_value [ type1 ; type2 ]) | T_variable (type_name) -> P_variable type_name | T_constant (type_name) -> - let csttag = Core.(match type_name with + let csttag = T.(match type_name with | TC_unit -> C_unit | TC_string -> C_string | TC_nat -> C_nat @@ -62,9 +63,9 @@ let rec type_expression_to_type_value : T.type_expression -> O.type_value = fun | TC_void -> C_unit (* TODO : replace with void *) ) in - P_constant (csttag, []) + p_constant csttag [] | T_operator (type_operator) -> - let (csttag, args) = Core.(match type_operator with + let (csttag, args) = T.(match type_operator with | TC_option o -> (C_option, [o]) | TC_set s -> (C_set, [s]) | TC_map { k ; v } -> (C_map, [k;v]) @@ -75,30 +76,30 @@ let rec type_expression_to_type_value : T.type_expression -> O.type_value = fun | TC_contract c -> (C_contract, [c]) ) in - P_constant (csttag, List.map type_expression_to_type_value args) + p_constant csttag (List.map type_expression_to_type_value args) let rec type_expression_to_type_value_copypasted : I.type_expression -> O.type_value = fun te -> match te.type_content with | T_sum kvmap -> let () = failwith "fixme: don't use to_list, it drops the variant keys, rows have a differnt kind than argument lists for now!" in let tlist = List.map (fun ({ctor_type;_}:I.ctor_content) -> ctor_type) (I.CMap.to_list kvmap) in - P_constant (C_variant, List.map type_expression_to_type_value_copypasted tlist) + p_constant C_variant (List.map type_expression_to_type_value_copypasted tlist) | T_record kvmap -> let () = failwith "fixme: don't use to_list, it drops the record keys, rows have a differnt kind than argument lists for now!" in let tlist = List.map (fun ({field_type;_}:I.field_content) -> field_type) (I.LMap.to_list kvmap) in - P_constant (C_record, List.map type_expression_to_type_value_copypasted tlist) + p_constant C_record (List.map type_expression_to_type_value_copypasted tlist) | T_arrow {type1;type2} -> - P_constant (C_arrow, List.map type_expression_to_type_value_copypasted [ type1 ; type2 ]) + p_constant C_arrow (List.map type_expression_to_type_value_copypasted [ type1 ; type2 ]) | T_variable type_name -> P_variable (type_name) (* eird stuff*) | T_constant (type_name) -> - let csttag = Core.(match type_name with + let csttag = T.(match type_name with | TC_unit -> C_unit | TC_string -> C_string | _ -> failwith "unknown type constructor") in - P_constant (csttag,[]) + p_constant csttag [] | T_operator (type_name) -> - let (csttag, args) = Core.(match type_name with + let (csttag, args) = T.(match type_name with | TC_option o -> (C_option , [o]) | TC_list l -> (C_list , [l]) | TC_set s -> (C_set , [s]) @@ -109,7 +110,7 @@ let rec type_expression_to_type_value_copypasted : I.type_expression -> O.type_v | TC_arrow ( arg , ret ) -> (C_arrow, [ arg ; ret ]) ) in - P_constant (csttag, List.map type_expression_to_type_value_copypasted args) + p_constant csttag (List.map type_expression_to_type_value_copypasted args) let failwith_ : unit -> (constraints * O.type_variable) = fun () -> let type_name = Core.fresh_type_variable () in @@ -118,12 +119,12 @@ let failwith_ : unit -> (constraints * O.type_variable) = fun () -> let variable : I.expression_variable -> T.type_expression -> (constraints * T.type_variable) = fun _name expr -> let pattern = type_expression_to_type_value expr in let type_name = Core.fresh_type_variable () in - [C_equation (P_variable (type_name) , pattern)] , type_name + [C_equation { aval = P_variable type_name ; bval = pattern }] , type_name let literal : T.type_expression -> (constraints * T.type_variable) = fun t -> let pattern = type_expression_to_type_value t in let type_name = Core.fresh_type_variable () in - [C_equation (P_variable (type_name) , pattern)] , type_name + [C_equation { aval = P_variable type_name ; bval = pattern }] , type_name (* let literal_bool : unit -> (constraints * O.type_variable) = fun () -> @@ -139,9 +140,9 @@ let literal : T.type_expression -> (constraints * T.type_variable) = fun t -> let tuple : T.type_expression list -> (constraints * T.type_variable) = fun tys -> let patterns = List.map type_expression_to_type_value tys in - let pattern = O.(P_constant (C_record , patterns)) in + let pattern = p_constant C_record patterns in let type_name = Core.fresh_type_variable () in - [C_equation (P_variable (type_name) , pattern)] , type_name + [C_equation { aval = P_variable type_name ; bval = pattern}] , type_name (* let t_tuple = ('label:int, 'v) … -> record ('label : 'v) … *) (* let t_constructor = ('label:string, 'v) -> variant ('label : 'v) *) @@ -170,8 +171,9 @@ end let access_label ~(base : T.type_expression) ~(label : O.accessor) : (constraints * T.type_variable) = let base' = type_expression_to_type_value base in let expr_type = Core.fresh_type_variable () in - [O.C_access_label (base' , label , expr_type)] , expr_type + [T.C_access_label { c_access_label_tval = base' ; accessor = label ; c_access_label_tvar = expr_type }] , expr_type +open Ast_typed.Misc let constructor : T.type_expression -> T.type_expression -> T.type_expression -> (constraints * T.type_variable) = fun t_arg c_arg sum -> @@ -180,64 +182,64 @@ let constructor let sum = type_expression_to_type_value sum in let whole_expr = Core.fresh_type_variable () in [ - C_equation (P_variable (whole_expr) , sum) ; - C_equation (t_arg , c_arg) + c_equation (P_variable whole_expr) sum ; + c_equation t_arg c_arg ; ] , whole_expr let record : T.field_content T.label_map -> (constraints * T.type_variable) = fun fields -> let record_type = type_expression_to_type_value (T.t_record fields ()) in let whole_expr = Core.fresh_type_variable () in - [C_equation (P_variable whole_expr , record_type)] , whole_expr + [c_equation (P_variable whole_expr) record_type] , whole_expr let collection : O.constant_tag -> T.type_expression list -> (constraints * T.type_variable) = fun ctor element_tys -> - let elttype = O.P_variable (Core.fresh_type_variable ()) in + let elttype = T.P_variable (Core.fresh_type_variable ()) in let aux elt = let elt' = type_expression_to_type_value elt - in O.C_equation (elttype , elt') in + in c_equation elttype elt' in let equations = List.map aux element_tys in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (P_variable whole_expr , O.P_constant (ctor , [elttype])) + [ + c_equation (P_variable whole_expr) (p_constant ctor [elttype]) ; ] @ equations , whole_expr -let list = collection O.C_list -let set = collection O.C_set +let list = collection T.C_list +let set = collection T.C_set let map : (T.type_expression * T.type_expression) list -> (constraints * T.type_variable) = fun kv_tys -> - let k_type = O.P_variable (Core.fresh_type_variable ()) in - let v_type = O.P_variable (Core.fresh_type_variable ()) in + let k_type = T.P_variable (Core.fresh_type_variable ()) in + let v_type = T.P_variable (Core.fresh_type_variable ()) in let aux_k (k , _v) = let k' = type_expression_to_type_value k in - O.C_equation (k_type , k') in + c_equation k_type k' in let aux_v (_k , v) = let v' = type_expression_to_type_value v in - O.C_equation (v_type , v') in + c_equation v_type v' in let equations_k = List.map aux_k kv_tys in let equations_v = List.map aux_v kv_tys in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (P_variable whole_expr , O.P_constant (C_map , [k_type ; v_type])) + [ + c_equation (P_variable whole_expr) (p_constant C_map [k_type ; v_type]) ; ] @ equations_k @ equations_v , whole_expr let big_map : (T.type_expression * T.type_expression) list -> (constraints * T.type_variable) = fun kv_tys -> - let k_type = O.P_variable (Core.fresh_type_variable ()) in - let v_type = O.P_variable (Core.fresh_type_variable ()) in + let k_type = T.P_variable (Core.fresh_type_variable ()) in + let v_type = T.P_variable (Core.fresh_type_variable ()) in let aux_k (k , _v) = let k' = type_expression_to_type_value k in - O.C_equation (k_type , k') in + c_equation k_type k' in let aux_v (_k , v) = let v' = type_expression_to_type_value v in - O.C_equation (v_type , v') in + c_equation v_type v' in let equations_k = List.map aux_k kv_tys in let equations_v = List.map aux_v kv_tys in let whole_expr = Core.fresh_type_variable () in - O.[ + [ (* TODO: this doesn't tag big_maps uniquely (i.e. if two big_map have the same type, they can be swapped. *) - C_equation (P_variable whole_expr , O.P_constant (C_big_map , [k_type ; v_type])) + c_equation (P_variable whole_expr) (p_constant C_big_map [k_type ; v_type]) ; ] @ equations_k @ equations_v , whole_expr let application : T.type_expression -> T.type_expression -> (constraints * T.type_variable) = @@ -245,8 +247,8 @@ let application : T.type_expression -> T.type_expression -> (constraints * T.typ let whole_expr = Core.fresh_type_variable () in let f' = type_expression_to_type_value f in let arg' = type_expression_to_type_value arg in - O.[ - C_equation (f' , P_constant (C_arrow , [arg' ; P_variable whole_expr])) + [ + c_equation f' (p_constant C_arrow [arg' ; P_variable whole_expr]) ; ] , whole_expr let look_up : T.type_expression -> T.type_expression -> (constraints * T.type_variable) = @@ -255,9 +257,9 @@ let look_up : T.type_expression -> T.type_expression -> (constraints * T.type_va let ind' = type_expression_to_type_value ind in let whole_expr = Core.fresh_type_variable () in let v = Core.fresh_type_variable () in - O.[ - C_equation (ds' , P_constant (C_map, [ind' ; P_variable v])) ; - C_equation (P_variable whole_expr , P_constant (C_option , [P_variable v])) + [ + c_equation ds' (p_constant C_map [ind' ; P_variable v]) ; + c_equation (P_variable whole_expr) (p_constant C_option [P_variable v]) ; ] , whole_expr let sequence : T.type_expression -> T.type_expression -> (constraints * T.type_variable) = @@ -265,9 +267,9 @@ let sequence : T.type_expression -> T.type_expression -> (constraints * T.type_v let a' = type_expression_to_type_value a in let b' = type_expression_to_type_value b in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (a' , P_constant (C_unit , [])) ; - C_equation (b' , P_variable whole_expr) + [ + c_equation a' (p_constant C_unit []) ; + c_equation b' (P_variable whole_expr) ; ] , whole_expr let loop : T.type_expression -> T.type_expression -> (constraints * T.type_variable) = @@ -275,10 +277,10 @@ let loop : T.type_expression -> T.type_expression -> (constraints * T.type_varia let expr' = type_expression_to_type_value expr in let body' = type_expression_to_type_value body in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (expr' , P_variable (Stage_common.Constant.t_bool)) ; - C_equation (body' , P_constant (C_unit , [])) ; - C_equation (P_variable whole_expr , P_constant (C_unit , [])) + [ + c_equation expr' (P_variable (Stage_common.Constant.t_bool)) ; + c_equation body' (p_constant C_unit []) ; + c_equation (P_variable whole_expr) (p_constant C_unit []) ] , whole_expr let let_in : T.type_expression -> T.type_expression option -> T.type_expression -> (constraints * T.type_variable) = @@ -287,18 +289,18 @@ let let_in : T.type_expression -> T.type_expression option -> T.type_expression let result' = type_expression_to_type_value result in let rhs_tv_opt' = match rhs_tv_opt with None -> [] - | Some annot -> O.[C_equation (rhs' , type_expression_to_type_value annot)] in + | Some annot -> [c_equation rhs' (type_expression_to_type_value annot)] in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (result' , P_variable whole_expr) + [ + c_equation result' (P_variable whole_expr) ; ] @ rhs_tv_opt', whole_expr let recursive : T.type_expression -> (constraints * T.type_variable) = fun fun_type -> let fun_type = type_expression_to_type_value fun_type in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (fun_type, P_variable whole_expr) + [ + c_equation fun_type (P_variable whole_expr) ; ], whole_expr let assign : T.type_expression -> T.type_expression -> (constraints * T.type_variable) = @@ -306,9 +308,9 @@ let assign : T.type_expression -> T.type_expression -> (constraints * T.type_var let v' = type_expression_to_type_value v in let e' = type_expression_to_type_value e in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (v' , e') ; - C_equation (P_variable whole_expr , P_constant (C_unit , [])) + [ + c_equation v' e' ; + c_equation (P_variable whole_expr) (p_constant C_unit []) ; ] , whole_expr let annotation : T.type_expression -> T.type_expression -> (constraints * T.type_variable) = @@ -316,16 +318,16 @@ let annotation : T.type_expression -> T.type_expression -> (constraints * T.type let e' = type_expression_to_type_value e in let annot' = type_expression_to_type_value annot in let whole_expr = Core.fresh_type_variable () in - O.[ - C_equation (e' , annot') ; - C_equation (e' , P_variable whole_expr) + [ + c_equation e' annot' ; + c_equation e' (P_variable whole_expr) ; ] , whole_expr let matching : T.type_expression list -> (constraints * T.type_variable) = fun es -> let whole_expr = Core.fresh_type_variable () in let type_expressions = (List.map type_expression_to_type_value es) in - let cs = List.map (fun e -> O.C_equation (P_variable whole_expr , e)) type_expressions + let cs = List.map (fun e -> c_equation (P_variable whole_expr) e) type_expressions in cs, whole_expr let fresh_binder () = @@ -342,15 +344,15 @@ let lambda let unification_body = Core.fresh_type_variable () in let arg' = match arg with None -> [] - | Some arg -> O.[C_equation (P_variable unification_arg , type_expression_to_type_value arg)] in + | Some arg -> [c_equation (P_variable unification_arg) (type_expression_to_type_value arg)] in let body' = match body with None -> [] - | Some body -> O.[C_equation (P_variable unification_body , type_expression_to_type_value body)] - in O.[ - C_equation (type_expression_to_type_value fresh , P_variable unification_arg) ; - C_equation (P_variable whole_expr , - P_constant (C_arrow , [P_variable unification_arg ; - P_variable unification_body])) + | Some body -> [c_equation (P_variable unification_body) (type_expression_to_type_value body)] + in [ + c_equation (type_expression_to_type_value fresh) (P_variable unification_arg) ; + c_equation (P_variable whole_expr) + (p_constant C_arrow ([P_variable unification_arg ; + P_variable unification_body])) ] @ arg' @ body' , whole_expr (* This is pretty much a wrapper for an n-ary function. *) @@ -358,7 +360,7 @@ let constant : O.type_value -> T.type_expression list -> (constraints * T.type_v fun f args -> let whole_expr = Core.fresh_type_variable () in let args' = List.map type_expression_to_type_value args in - let args_tuple = O.P_constant (C_record , args') in - O.[ - C_equation (f , P_constant (C_arrow , [args_tuple ; P_variable whole_expr])) + let args_tuple = p_constant C_record args' in + [ + c_equation f (p_constant C_arrow ([args_tuple ; P_variable whole_expr])) ] , whole_expr diff --git a/src/passes/8-typer-old/typer.ml b/src/passes/8-typer-old/typer.ml index bf5bdeb1b..3e3d0b646 100644 --- a/src/passes/8-typer-old/typer.ml +++ b/src/passes/8-typer-old/typer.ml @@ -466,7 +466,7 @@ let unconvert_constant' : O.constant' -> I.constant' = function | C_SET_DELEGATE -> C_SET_DELEGATE | C_CREATE_CONTRACT -> C_CREATE_CONTRACT -let rec type_program (p:I.program) : (O.program * Solver.state) result = +let rec type_program (p:I.program) : (O.program * O.typer_state) result = let aux (e, acc:(environment * O.declaration Location.wrap list)) (d:I.declaration Location.wrap) = let%bind ed' = (bind_map_location (type_declaration e (Solver.placeholder_for_state_of_new_typer ()))) d in let loc : 'a . 'a Location.wrap -> _ -> _ = fun x v -> Location.wrap ~loc:x.location v in @@ -480,7 +480,7 @@ let rec type_program (p:I.program) : (O.program * Solver.state) result = bind_fold_list aux (DEnv.default, []) p in ok @@ (List.rev lst , (Solver.placeholder_for_state_of_new_typer ())) -and type_declaration env (_placeholder_for_state_of_new_typer : Solver.state) : I.declaration -> (environment * Solver.state * O.declaration option) result = function +and type_declaration env (_placeholder_for_state_of_new_typer : O.typer_state) : I.declaration -> (environment * O.typer_state * O.declaration option) result = function | Declaration_type (type_name , type_expression) -> let%bind tv = evaluate_type env type_expression in let env' = Environment.add_type (type_name) tv env in @@ -659,7 +659,7 @@ and evaluate_type (e:environment) (t:I.type_expression) : O.type_expression resu in return (T_operator (opt)) -and type_expression : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result +and type_expression : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result = fun e _placeholder_for_state_of_new_typer ?tv_opt ae -> let%bind res = type_expression' e ?tv_opt ae in ok (res, (Solver.placeholder_for_state_of_new_typer ())) diff --git a/src/passes/8-typer-old/typer.mli b/src/passes/8-typer-old/typer.mli index d1bf21393..ff7009a8c 100644 --- a/src/passes/8-typer-old/typer.mli +++ b/src/passes/8-typer-old/typer.mli @@ -38,11 +38,11 @@ module Errors : sig *) end -val type_program : I.program -> (O.program * Solver.state) result -val type_declaration : environment -> Solver.state -> I.declaration -> (environment * Solver.state * O.declaration option) result +val type_program : I.program -> (O.program * O.typer_state) result +val type_declaration : environment -> O.typer_state -> I.declaration -> (environment * O.typer_state * O.declaration option) result (* val type_match : (environment -> 'i -> 'o result) -> environment -> O.type_value -> 'i I.matching -> I.expression -> Location.t -> 'o O.matching result *) val evaluate_type : environment -> I.type_expression -> O.type_expression result -val type_expression : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result +val type_expression : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result val type_constant : I.constant' -> O.type_expression list -> O.type_expression option -> (O.constant' * O.type_expression) result (* val untype_type_value : O.type_value -> (I.type_expression) result diff --git a/src/passes/8-typer/typer.mli b/src/passes/8-typer/typer.mli index bf4c11f4d..8069ab943 100644 --- a/src/passes/8-typer/typer.mli +++ b/src/passes/8-typer/typer.mli @@ -11,6 +11,6 @@ module Solver = Typer_new.Solver type environment = Environment.t -val type_program : I.program -> (O.program * Solver.state) result -val type_expression_subst : environment -> Solver.state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * Solver.state) result +val type_program : I.program -> (O.program * O.typer_state) result +val type_expression_subst : environment -> O.typer_state -> ?tv_opt:O.type_expression -> I.expression -> (O.expression * O.typer_state) result val untype_expression : O.expression -> I.expression result diff --git a/src/stages/4-ast_typed/.gitignore b/src/stages/4-ast_typed/.gitignore new file mode 100644 index 000000000..39f5407d5 --- /dev/null +++ b/src/stages/4-ast_typed/.gitignore @@ -0,0 +1,2 @@ +/generated_fold.ml + diff --git a/src/stages/4-ast_typed/PP_generic.ml b/src/stages/4-ast_typed/PP_generic.ml index d698c70e7..fc9bbeb9d 100644 --- a/src/stages/4-ast_typed/PP_generic.ml +++ b/src/stages/4-ast_typed/PP_generic.ml @@ -15,6 +15,7 @@ let needs_parens = { z = (fun _ _ _ -> false) ; string = (fun _ _ _ -> false) ; bytes = (fun _ _ _ -> false) ; + unit = (fun _ _ _ -> false) ; packed_internal_operation = (fun _ _ _ -> false) ; expression_variable = (fun _ _ _ -> false) ; constructor' = (fun _ _ _ -> false) ; @@ -28,6 +29,9 @@ let needs_parens = { list_ne = (fun _ _ _ _ -> false) ; option = (fun _visitor _continue _state o -> match o with None -> false | Some _ -> true) ; + poly_unionfind = (fun _ _ _ _ -> false) ; + poly_set = (fun _ _ _ _ -> false) ; + typeVariableMap = (fun _ _ _ _ -> false) ; } let op ppf = { @@ -49,6 +53,7 @@ let op ppf = { z = (fun _visitor () i -> fprintf ppf "%a" Z.pp_print i) ; string = (fun _visitor () str -> fprintf ppf "\"%s\"" str) ; bytes = (fun _visitor () _bytes -> fprintf ppf "bytes...") ; + unit = (fun _visitor () () -> fprintf ppf "()") ; packed_internal_operation = (fun _visitor () _op -> fprintf ppf "Operation(...bytes)") ; expression_variable = (fun _visitor () ev -> fprintf ppf "%a" Var.pp ev) ; constructor' = (fun _visitor () (Constructor c) -> fprintf ppf "Constructor %s" c) ; @@ -80,6 +85,17 @@ let op ppf = { match o with | None -> fprintf ppf "None" | Some v -> fprintf ppf "%a" (fun _ppf -> continue ()) v) ; + poly_unionfind = (fun _visitor continue () p -> + let lst = (UnionFind.Poly2.elements p) in + fprintf ppf "LMap [ %a ]" (list_sep (fun _ppf -> continue ()) (fun ppf () -> fprintf ppf " ; ")) lst); + poly_set = (fun _visitor continue () set -> + let lst = (RedBlackTrees.PolySet.elements set) in + fprintf ppf "LMap [ %a ]" (list_sep (fun _ppf -> continue ()) (fun ppf () -> fprintf ppf " ; ")) lst); + typeVariableMap = (fun _visitor continue () tvmap -> + let lst = List.sort (fun (a, _) (b, _) -> Var.compare a b) (RedBlackTrees.PolyMap.bindings tvmap) in + let aux ppf (k, v) = + fprintf ppf "(Var %a, %a)" Var.pp k (fun _ppf -> continue ()) v in + fprintf ppf "typeVariableMap [ %a ]" (list_sep aux (fun ppf () -> fprintf ppf " ; ")) lst); } let print : (unit fold_config -> unit -> 'a -> unit) -> formatter -> 'a -> unit = fun fold ppf v -> @@ -87,3 +103,4 @@ let print : (unit fold_config -> unit -> 'a -> unit) -> formatter -> 'a -> unit let program = print fold__program let type_expression = print fold__type_expression +let full_environment = print fold__full_environment diff --git a/src/stages/4-ast_typed/dune b/src/stages/4-ast_typed/dune index 7a16fdd2a..370845a60 100644 --- a/src/stages/4-ast_typed/dune +++ b/src/stages/4-ast_typed/dune @@ -14,6 +14,7 @@ ast_core ; Is that a good idea? stage_common adt_generator + UnionFind ) (preprocess (pps ppx_let bisect_ppx --conditional) diff --git a/src/stages/4-ast_typed/misc.ml b/src/stages/4-ast_typed/misc.ml index 1ab73cd5d..990c53288 100644 --- a/src/stages/4-ast_typed/misc.ml +++ b/src/stages/4-ast_typed/misc.ml @@ -526,6 +526,14 @@ let program_environment (program : program) : full_environment = | Declaration_constant { binder=_ ; expr=_ ; inline=_ ; post_env } -> post_env let equal_variables a b : bool = - match a.expression_content, b.expression_content with + match a.expression_content, b.expression_content with | E_variable a, E_variable b -> Var.equal a b | _, _ -> false + +let p_constant (p_ctor_tag : constant_tag) (p_ctor_args : p_ctor_args) = + P_constant { + p_ctor_tag : constant_tag ; + p_ctor_args : p_ctor_args ; + } + +let c_equation aval bval = C_equation { aval ; bval } diff --git a/src/stages/4-ast_typed/misc.mli b/src/stages/4-ast_typed/misc.mli index 924702ce8..fae2a1a36 100644 --- a/src/stages/4-ast_typed/misc.mli +++ b/src/stages/4-ast_typed/misc.mli @@ -71,3 +71,6 @@ val assert_literal_eq : ( literal * literal ) -> unit result val get_entry : program -> string -> expression result val program_environment : program -> full_environment + +val p_constant : constant_tag -> p_ctor_args -> type_value +val c_equation : type_value -> type_value -> type_constraint diff --git a/src/stages/4-ast_typed/types.ml b/src/stages/4-ast_typed/types.ml index ab7807765..e06b7ccd2 100644 --- a/src/stages/4-ast_typed/types.ml +++ b/src/stages/4-ast_typed/types.ml @@ -423,3 +423,190 @@ and named_type_content = { type_name : type_variable; type_value : type_expression; } + + + + + +(* Solver types *) + +(* typevariable: to_string = (fun s -> Format.asprintf "%a" Var.pp s) *) +type unionfind = type_variable poly_unionfind + +(* core *) + +(* add information on the type or the kind for operator *) +type constant_tag = + | C_arrow (* * -> * -> * isn't this wrong? *) + | C_option (* * -> * *) + | C_record (* ( label , * ) … -> * *) + | C_variant (* ( label , * ) … -> * *) + | C_map (* * -> * -> * *) + | C_big_map (* * -> * -> * *) + | C_list (* * -> * *) + | C_set (* * -> * *) + | C_unit (* * *) + | C_string (* * *) + | C_nat (* * *) + | C_mutez (* * *) + | C_timestamp (* * *) + | C_int (* * *) + | C_address (* * *) + | C_bytes (* * *) + | C_key_hash (* * *) + | C_key (* * *) + | C_signature (* * *) + | C_operation (* * *) + | C_contract (* * -> * *) + | C_chain_id (* * *) + +(* TODO: rename to type_expression or something similar (it includes variables, and unevaluated functions + applications *) +type type_value = + | P_forall of p_forall + | P_variable of type_variable + | P_constant of p_constant + | P_apply of p_apply + +and p_apply = { + tf : type_value ; + targ : type_value ; +} +and p_ctor_args = type_value list +and p_constant = { + p_ctor_tag : constant_tag ; + p_ctor_args : p_ctor_args ; + } +and p_constraints = type_constraint list +and p_forall = { + binder : type_variable ; + constraints : p_constraints ; + body : type_value ; +} + +(* Different type of constraint *) +and ctor_args = type_variable list (* non-empty list *) +and simple_c_constructor = { + ctor_tag : constant_tag ; + ctor_args : ctor_args ; + } +and simple_c_constant = { + constant_tag: constant_tag ; (* for type constructors that do not take arguments *) + } +and c_const = { + c_const_tvar : type_variable ; + c_const_tval : type_value ; + } +and c_equation = { + aval : type_value ; + bval : type_value ; +} +and tc_args = type_value list +and c_typeclass = { + tc_args : tc_args ; + typeclass : typeclass ; +} +and c_access_label = { + c_access_label_tval : type_value ; + accessor : label ; + c_access_label_tvar : type_variable ; + } + +(*What i was saying just before *) +and type_constraint = + (* | C_assignment of (type_variable * type_pattern) *) + | C_equation of c_equation (* TVA = TVB *) + | C_typeclass of c_typeclass (* TVL ∈ TVLs, for now in extension, later add intensional (rule-based system for inclusion in the typeclass) *) + | C_access_label of c_access_label (* poor man's type-level computation to ensure that TV.label is type_variable *) +(* | … *) + +(* is the first list in case on of the type of the type class as a kind *->*->* ? *) +and tc_allowed = type_value list +and typeclass = tc_allowed list + +(* end core *) + +type c_constructor_simpl_typeVariableMap = c_constructor_simpl typeVariableMap +and constraints_typeVariableMap = constraints typeVariableMap +and type_constraint_simpl_list = type_constraint_simpl list +and structured_dbs = { + all_constraints : type_constraint_simpl_list ; + aliases : unionfind ; + (* assignments (passive data structure). *) + (* Now : just a map from unification vars to types (pb: what about partial types?) *) + (* maybe just local assignments (allow only vars as children of pair(α,β)) *) + (* TODO : the rhs of the map should not repeat the variable name. *) + assignments : c_constructor_simpl_typeVariableMap ; + grouped_by_variable : constraints_typeVariableMap ; (* map from (unionfind) variables to constraints containing them *) + cycle_detection_toposort : unit ; (* example of structured db that we'll add later *) +} + +and c_constructor_simpl_list = c_constructor_simpl list +and c_poly_simpl_list = c_poly_simpl list +and c_typeclass_simpl_list = c_typeclass_simpl list +and constraints = { + (* If implemented in a language with decent sets, these should be sets not lists. *) + constructor : c_constructor_simpl_list ; (* List of ('a = constructor(args…)) constraints *) + poly : c_poly_simpl_list ; (* List of ('a = forall 'b, some_type) constraints *) + tc : c_typeclass_simpl_list ; (* List of (typeclass(args…)) constraints *) +} +and type_variable_list = type_variable list +and c_constructor_simpl = { + tv : type_variable; + c_tag : constant_tag; + tv_list : type_variable_list; +} +and c_const_e = { + c_const_e_tv : type_variable ; + c_const_e_te : type_expression ; + } +and c_equation_e = { + aex : type_expression ; + bex : type_expression ; + } +and c_typeclass_simpl = { + tc : typeclass ; + args : type_variable_list ; +} +and c_poly_simpl = { + tv : type_variable ; + forall : p_forall ; +} +and type_constraint_simpl = + | SC_Constructor of c_constructor_simpl (* α = ctor(β, …) *) + | SC_Alias of c_alias (* α = β *) + | SC_Poly of c_poly_simpl (* α = forall β, δ where δ can be a more complex type *) + | SC_Typeclass of c_typeclass_simpl (* TC(α, …) *) + +and c_alias = { + a : type_variable ; + b : type_variable ; + } + + +(* sub-sub component: lazy selector (don't re-try all selectors every time) *) +(* For now: just re-try everytime *) + +(* selector / propagation rule for breaking down composite types *) +(* For now: break pair(a, b) = pair(c, d) into a = c, b = d *) +type output_break_ctor = { + a_k_var : c_constructor_simpl ; + a_k'_var' : c_constructor_simpl ; + } + +type output_specialize1 = { + poly : c_poly_simpl ; + a_k_var : c_constructor_simpl ; + } + +type m_break_ctor__already_selected = output_break_ctor poly_set +type m_specialize1__already_selected = output_specialize1 poly_set + +type already_selected = { + break_ctor : m_break_ctor__already_selected ; + specialize1 : m_specialize1__already_selected ; +} + +type typer_state = { + structured_dbs : structured_dbs ; + already_selected : already_selected ; +} diff --git a/src/stages/4-ast_typed/types_utils.ml b/src/stages/4-ast_typed/types_utils.ml index 34e7c5668..b9367fa0c 100644 --- a/src/stages/4-ast_typed/types_utils.ml +++ b/src/stages/4-ast_typed/types_utils.ml @@ -77,3 +77,48 @@ let fold_map__option : type a state new_a . (state -> a -> (state * new_a) resul match o with | None -> ok (state, None) | Some v -> let%bind state, v = f state v in ok (state, Some v) + + + + + +(* Solver types *) + +type 'a poly_unionfind = 'a UnionFind.Poly2.t + +(* typevariable: to_string = (fun s -> Format.asprintf "%a" Var.pp s) *) +(* representant for an equivalence class of type variables *) +type 'v typeVariableMap = (type_variable, 'v) RedBlackTrees.PolyMap.t + +type 'a poly_set = 'a RedBlackTrees.PolySet.t + +let fold_map__poly_unionfind : type a state new_a . (state -> a -> (state * new_a) result) -> state -> a poly_unionfind -> (state * new_a poly_unionfind) Simple_utils.Trace.result = + fun f state l -> + ignore (f, state, l) ; failwith "TODO + let aux acc element = + let%bind state , l = acc in + let%bind (state , new_element) = f state element in ok (state , new_element :: l) in + let%bind (state , l) = List.fold_left aux (ok (state , [])) l in + ok (state , l)" + +let fold_map__PolyMap : type k v state new_v . (state -> v -> (state * new_v) result) -> state -> (k, v) PolyMap.t -> (state * (k, new_v) PolyMap.t) result = + fun f state m -> + let aux k v ~acc = + let%bind (state , m) = acc in + let%bind (state , new_v) = f state v in + ok (state , PolyMap.add k new_v m) in + let%bind (state , m) = PolyMap.fold_inc aux m ~init:(ok (state, PolyMap.empty m)) in + ok (state , m) + +let fold_map__typeVariableMap : type a state new_a . (state -> a -> (state * new_a) result) -> state -> a typeVariableMap -> (state * new_a typeVariableMap) result = + fold_map__PolyMap + +let fold_map__poly_set : type a state new_a . (state -> a -> (state * new_a) result) -> state -> a poly_set -> (state * new_a poly_set) result = + fun f state s -> + let new_compare : (new_a -> new_a -> int) = failwith "TODO: thread enough information about the target AST so that we may compare things here." in + let aux elt ~acc = + let%bind (state , s) = acc in + let%bind (state , new_elt) = f state elt in + ok (state , PolySet.add new_elt s) in + let%bind (state , m) = PolySet.fold_inc aux s ~init:(ok (state, PolySet.create ~cmp:new_compare)) in + ok (state , m) diff --git a/src/stages/adt_generator/adt_generator.ml b/src/stages/adt_generator/adt_generator.ml index f96857f7b..11f617517 100644 --- a/src/stages/adt_generator/adt_generator.ml +++ b/src/stages/adt_generator/adt_generator.ml @@ -1 +1,2 @@ module Generic = Generic +module Common = Common diff --git a/src/stages/adt_generator/common.ml b/src/stages/adt_generator/common.ml new file mode 100644 index 000000000..890711eb9 --- /dev/null +++ b/src/stages/adt_generator/common.ml @@ -0,0 +1,3 @@ +type ('a,'err) monad = ('a) Simple_utils.Trace.result;; +let (>>?) v f = Simple_utils.Trace.bind f v;; +let return v = Simple_utils.Trace.ok v;; diff --git a/src/stages/adt_generator/dune b/src/stages/adt_generator/dune index e9f2660b3..5e98e3845 100644 --- a/src/stages/adt_generator/dune +++ b/src/stages/adt_generator/dune @@ -1,5 +1,7 @@ (library (name adt_generator) (public_name ligo.adt_generator) - (libraries) + (libraries + simple-utils + ) ) diff --git a/src/stages/adt_generator/generator.raku b/src/stages/adt_generator/generator.raku index 11d0c5e91..725b59415 100644 --- a/src/stages/adt_generator/generator.raku +++ b/src/stages/adt_generator/generator.raku @@ -3,6 +3,11 @@ use v6.c; use strict; use worries; +# TODO: find a way to do mutual recursion between the produced file and some #include-y-thingy +# TODO: make an .mli +# TODO: shorthand for `foo list` etc. in field and constructor types +# TODO: error when reserved names are used ("state", … please list them here) + my $moduleName = @*ARGS[0].subst(/\.ml$/, '').samecase("A_"); my $variant = "_ _variant"; my $record = "_ _ record"; @@ -143,9 +148,7 @@ say ""; for $statements -> $statement { say "$statement" } -say "type ('a,'err) monad = ('a) Simple_utils.Trace.result;;"; -say "let (>>?) v f = Simple_utils.Trace.bind f v;;"; -say "let return v = Simple_utils.Trace.ok v;;"; +say "open Adt_generator.Common;;"; say "open $moduleName;;"; say ""; @@ -182,47 +185,37 @@ say ";;"; say ""; for $adts.list -> $t { - say "type ('state, 'err) continue_fold_map__$t = \{"; + say "type ('state, 'err) _continue_fold_map__$t = \{"; say " node__$t : 'state -> $t -> ('state * $t , 'err) monad ;"; for $t.list -> $c { say " $t__$c : 'state -> {$c || 'unit'} -> ('state * {$c || 'unit'} , 'err) monad ;" } say ' };;'; } -say "type ('state , 'err) continue_fold_map = \{"; +say "type ('state , 'err) _continue_fold_map__$moduleName = \{"; for $adts.list -> $t { - say " $t : ('state , 'err) continue_fold_map__$t ;"; + say " $t : ('state , 'err) _continue_fold_map__$t ;"; } say ' };;'; say ""; for $adts.list -> $t -{ say "type ('state , 'err) fold_map_config__$t = \{"; - say " node__$t : 'state -> $t -> ('state, 'err) continue_fold_map -> ('state * $t , 'err) monad ;"; # (*Adt_info.node_instance_info ->*) +{ say "type ('state, 'err) fold_map_config__$t = \{"; + say " node__$t : 'state -> $t -> ('state, 'err) _continue_fold_map__$moduleName -> ('state * $t , 'err) monad ;"; # (*Adt_info.node_instance_info ->*) say " node__$t__pre_state : 'state -> $t -> ('state, 'err) monad ;"; # (*Adt_info.node_instance_info ->*) say " node__$t__post_state : 'state -> $t -> $t -> ('state, 'err) monad ;"; # (*Adt_info.node_instance_info ->*) for $t.list -> $c - { say " $t__$c : 'state -> {$c || 'unit'} -> ('state, 'err) continue_fold_map -> ('state * {$c || 'unit'} , 'err) monad ;"; # (*Adt_info.ctor_or_field_instance_info ->*) + { say " $t__$c : 'state -> {$c || 'unit'} -> ('state, 'err) _continue_fold_map__$moduleName -> ('state * {$c || 'unit'} , 'err) monad ;"; # (*Adt_info.ctor_or_field_instance_info ->*) } say '};;' } -say "type ('state, 'err) fold_map_config ="; +say "type ('state, 'err) fold_map_config__$moduleName ="; say ' {'; for $adts.list -> $t { say " $t : ('state, 'err) fold_map_config__$t;" } say ' };;'; -say ""; -say "module StringMap = Map.Make(String);;"; -say "(* generic folds for nodes *)"; -say "type 'state generic_continue_fold_node = \{"; -say " continue : 'state -> 'state ;"; -say " (* generic folds for each field *)"; -say " continue_ctors_or_fields : ('state -> 'state) StringMap.t ;"; -say '};;'; -say "(* map from node names to their generic folds *)"; -say "type 'state generic_continue_fold = ('state generic_continue_fold_node) StringMap.t;;"; -say ""; +say "include Adt_generator.Generic.BlahBluh"; say "type ('state , 'adt_info_node_instance_info) _fold_config ="; say ' {'; say " generic : 'state -> 'adt_info_node_instance_info -> 'state;"; @@ -372,23 +365,23 @@ for $adts.list -> $t say ""; say "type ('state, 'err) mk_continue_fold_map = \{"; -say " fn : ('state,'err) mk_continue_fold_map -> ('state, 'err) fold_map_config -> ('state , 'err) continue_fold_map"; +say " fn : ('state, 'err) mk_continue_fold_map -> ('state, 'err) fold_map_config__$moduleName -> ('state, 'err) _continue_fold_map__$moduleName"; say '};;'; # fold_map functions say ""; for $adts.list -> $t -{ say "let _fold_map__$t : type qstate err . (qstate,err) mk_continue_fold_map -> (qstate,err) fold_map_config -> qstate -> $t -> (qstate * $t, err) monad = fun mk_continue_fold_map visitor state x ->"; - say " let continue_fold_map : (qstate,err) continue_fold_map = mk_continue_fold_map.fn mk_continue_fold_map visitor in"; +{ say "let _fold_map__$t : type qstate err . (qstate,err) mk_continue_fold_map -> (qstate,err) fold_map_config__$moduleName -> qstate -> $t -> (qstate * $t, err) monad = fun mk_continue_fold_map visitor state x ->"; + say " let continue_fold_map : (qstate,err) _continue_fold_map__$moduleName = mk_continue_fold_map.fn mk_continue_fold_map visitor in"; say " visitor.$t.node__$t__pre_state state x >>? fun state ->"; # (*(fun () -> whole_adt_info, info__$t)*) say " visitor.$t.node__$t state x continue_fold_map >>? fun (state, new_x) ->"; # (*(fun () -> whole_adt_info, info__$t)*) say " visitor.$t.node__$t__post_state state x new_x >>? fun state ->"; # (*(fun () -> whole_adt_info, info__$t)*) say " return (state, new_x);;"; say ""; for $t.list -> $c - { say "let _fold_map__$t__$c : type qstate err . (qstate,err) mk_continue_fold_map -> (qstate,err) fold_map_config -> qstate -> { $c || 'unit' } -> (qstate * { $c || 'unit' }, err) monad = fun mk_continue_fold_map visitor state x ->"; - say " let continue_fold_map : (qstate,err) continue_fold_map = mk_continue_fold_map.fn mk_continue_fold_map visitor in"; + { say "let _fold_map__$t__$c : type qstate err . (qstate,err) mk_continue_fold_map -> (qstate,err) fold_map_config__$moduleName -> qstate -> { $c || 'unit' } -> (qstate * { $c || 'unit' }, err) monad = fun mk_continue_fold_map visitor state x ->"; + say " let continue_fold_map : (qstate,err) _continue_fold_map__$moduleName = mk_continue_fold_map.fn mk_continue_fold_map visitor in"; say " visitor.$t.$t__$c state x continue_fold_map;;"; # (*(fun () -> whole_adt_info, info__$t, info__$t__$c)*) say ""; } } @@ -410,16 +403,16 @@ say ""; # fold_map functions : tying the knot say ""; for $adts.list -> $t -{ say "let fold_map__$t : type qstate err . (qstate,err) fold_map_config -> qstate -> $t -> (qstate * $t,err) monad ="; +{ say "let fold_map__$t : type qstate err . (qstate,err) fold_map_config__$moduleName -> qstate -> $t -> (qstate * $t,err) monad ="; say " fun visitor state x -> _fold_map__$t mk_continue_fold_map visitor state x;;"; for $t.list -> $c - { say "let fold_map__$t__$c : type qstate err . (qstate,err) fold_map_config -> qstate -> { $c || 'unit' } -> (qstate * { $c || 'unit' },err) monad ="; + { say "let fold_map__$t__$c : type qstate err . (qstate,err) fold_map_config__$moduleName -> qstate -> { $c || 'unit' } -> (qstate * { $c || 'unit' },err) monad ="; say " fun visitor state x -> _fold_map__$t__$c mk_continue_fold_map visitor state x;;"; } } for $adts.list -> $t { - say "let no_op_node__$t : type state . state -> $t -> (state,_) continue_fold_map -> (state * $t,_) monad ="; + say "let no_op_node__$t : type state . state -> $t -> (state,_) _continue_fold_map__$moduleName -> (state * $t,_) monad ="; say " fun state v continue ->"; # (*_info*) say " match v with"; if ($t eq $variant) { @@ -460,15 +453,15 @@ for $adts.list -> $t say ") ;"; } say ' }' } -say "let no_op : type state . (state,_) fold_map_config = \{"; +say "let no_op : type state . (state,_) fold_map_config__$moduleName = \{"; for $adts.list -> $t { say " $t = no_op__$t;" } say '};;'; say ""; for $adts.list -> $t -{ say "let with__$t : _ -> _ fold_map_config -> _ fold_map_config = (fun node__$t op -> \{ op with $t = \{ op.$t with node__$t \} \});;"; - say "let with__$t__pre_state : _ -> _ fold_map_config -> _ fold_map_config = (fun node__$t__pre_state op -> \{ op with $t = \{ op.$t with node__$t__pre_state \} \});;"; - say "let with__$t__post_state : _ -> _ fold_map_config -> _ fold_map_config = (fun node__$t__post_state op -> \{ op with $t = \{ op.$t with node__$t__post_state \} \});;"; +{ say "let with__$t : _ -> _ fold_map_config__$moduleName -> _ fold_map_config__$moduleName = (fun node__$t op -> \{ op with $t = \{ op.$t with node__$t \} \});;"; + say "let with__$t__pre_state : _ -> _ fold_map_config__$moduleName -> _ fold_map_config__$moduleName = (fun node__$t__pre_state op -> \{ op with $t = \{ op.$t with node__$t__pre_state \} \});;"; + say "let with__$t__post_state : _ -> _ fold_map_config__$moduleName -> _ fold_map_config__$moduleName = (fun node__$t__post_state op -> \{ op with $t = \{ op.$t with node__$t__post_state \} \});;"; for $t.list -> $c - { say "let with__$t__$c : _ -> _ fold_map_config -> _ fold_map_config = (fun $t__$c op -> \{ op with $t = \{ op.$t with $t__$c \} \});;"; } } + { say "let with__$t__$c : _ -> _ fold_map_config__$moduleName -> _ fold_map_config__$moduleName = (fun $t__$c op -> \{ op with $t = \{ op.$t with $t__$c \} \});;"; } } diff --git a/src/stages/adt_generator/generic.ml b/src/stages/adt_generator/generic.ml index c4f28821a..c48ca1ac1 100644 --- a/src/stages/adt_generator/generic.ml +++ b/src/stages/adt_generator/generic.ml @@ -1,3 +1,15 @@ +module BlahBluh = struct +module StringMap = Map.Make(String);; +(* generic folds for nodes *) +type 'state generic_continue_fold_node = { + continue : 'state -> 'state ; + (* generic folds for each field *) + continue_ctors_or_fields : ('state -> 'state) StringMap.t ; +};; +(* map from node names to their generic folds *) +type 'state generic_continue_fold = ('state generic_continue_fold_node) StringMap.t;; +end + module Adt_info (M : sig type ('state , 'adt_info_node_instance_info) fold_config end) = struct type kind = | Record diff --git a/src/stages/typesystem/core.ml b/src/stages/typesystem/core.ml index 1dab4fb13..07f78184e 100644 --- a/src/stages/typesystem/core.ml +++ b/src/stages/typesystem/core.ml @@ -1,3 +1,27 @@ +type unionfind = Ast_typed.unionfind +type constant_tag = Ast_typed.constant_tag +type accessor = Ast_typed.label +type type_value = Ast_typed.type_value +type p_constraints = Ast_typed.p_constraints +type p_forall = Ast_typed.p_forall +type simple_c_constructor = Ast_typed.simple_c_constructor +type simple_c_constant = Ast_typed.simple_c_constant +type c_const = Ast_typed.c_const +type c_equation = Ast_typed.c_equation +type c_typeclass = Ast_typed.c_typeclass +type c_access_label = Ast_typed.c_access_label +type type_constraint = Ast_typed.type_constraint +type typeclass = Ast_typed.typeclass +type 'a typeVariableMap = 'a Ast_typed.typeVariableMap +type structured_dbs = Ast_typed.structured_dbs +type constraints = Ast_typed.constraints +type c_constructor_simpl = Ast_typed.c_constructor_simpl +type c_const_e = Ast_typed.c_const_e +type c_equation_e = Ast_typed.c_equation_e +type c_typeclass_simpl = Ast_typed.c_typeclass_simpl +type c_poly_simpl = Ast_typed.c_poly_simpl +type type_constraint_simpl = Ast_typed.type_constraint_simpl +type state = Ast_typed.typer_state type type_variable = Ast_typed.type_variable type type_expression = Ast_typed.type_expression @@ -6,68 +30,9 @@ type type_expression = Ast_typed.type_expression let fresh_type_variable : ?name:string -> unit -> type_variable = Var.fresh - -(* add information on the type or the kind for operator*) -type constant_tag = - | C_arrow (* * -> * -> * *) (* isn't this wrong*) - | C_option (* * -> * *) - | C_record (* ( label , * ) … -> * *) - | C_variant (* ( label , * ) … -> * *) - | C_map (* * -> * -> * *) - | C_big_map (* * -> * -> * *) - | C_list (* * -> * *) - | C_set (* * -> * *) - | C_unit (* * *) - | C_string (* * *) - | C_nat (* * *) - | C_mutez (* * *) - | C_timestamp (* * *) - | C_int (* * *) - | C_address (* * *) - | C_bytes (* * *) - | C_key_hash (* * *) - | C_key (* * *) - | C_signature (* * *) - | C_operation (* * *) - | C_contract (* * -> * *) - | C_chain_id (* * *) - -type accessor = Ast_typed.label - -(* Weird stuff; please explain *) -type type_value = - | P_forall of p_forall - | P_variable of type_variable (* how a value can be a variable? *) - | P_constant of (constant_tag * type_value list) - | P_apply of (type_value * type_value) - -and p_forall = { - binder : type_variable ; - constraints : type_constraint list ; - body : type_value -} - -(* Different type of constraint *) (* why isn't this a variant ? *) -and simple_c_constructor = (constant_tag * type_variable list) (* non-empty list *) -and simple_c_constant = (constant_tag) (* for type constructors that do not take arguments *) -and c_const = (type_variable * type_value) -and c_equation = (type_value * type_value) -and c_typeclass = (type_value list * typeclass) -and c_access_label = (type_value * accessor * type_variable) - -(*What i was saying just before *) -and type_constraint = - (* | C_assignment of (type_variable * type_pattern) *) - | C_equation of c_equation (* TVA = TVB *) - | C_typeclass of c_typeclass (* TVL ∈ TVLs, for now in extension, later add intensional (rule-based system for inclusion in the typeclass) *) - | C_access_label of c_access_label (* poor man's type-level computation to ensure that TV.label is type_variable *) -(* | … *) - -(* is the first list in case on of the type of the type class as a kind *->*->* ? *) -and typeclass = type_value list list - open Trace -let type_expression'_of_simple_c_constant = function +let type_expression'_of_simple_c_constant : constant_tag * type_expression list -> Ast_typed.type_content result = fun (c, l) -> + match c, l with | C_contract , [x] -> ok @@ Ast_typed.T_operator(TC_contract x) | C_option , [x] -> ok @@ Ast_typed.T_operator(TC_option x) | C_list , [x] -> ok @@ Ast_typed.T_operator(TC_list x) diff --git a/src/stages/typesystem/misc.ml b/src/stages/typesystem/misc.ml index bea3e693e..f060a2810 100644 --- a/src/stages/typesystem/misc.ml +++ b/src/stages/typesystem/misc.ml @@ -223,16 +223,15 @@ module Substitution = struct and type_value ~tv ~substs = let self tv = type_value ~tv ~substs in let (v, expr) = substs in - match tv with + match (tv : type_value) with | P_variable v' when Var.equal v' v -> expr | P_variable _ -> tv - | P_constant (x , lst) -> ( + | P_constant {p_ctor_tag=x ; p_ctor_args=lst} -> ( let lst' = List.map self lst in - P_constant (x , lst') + P_constant {p_ctor_tag=x ; p_ctor_args=lst'} ) - | P_apply ab -> ( - let ab' = pair_map self ab in - P_apply ab' + | P_apply { tf; targ } -> ( + P_apply { tf = self tf ; targ = self targ } ) | P_forall p -> ( let aux c = constraint_ ~c ~substs in @@ -247,18 +246,18 @@ module Substitution = struct and constraint_ ~c ~substs = match c with - | C_equation ab -> ( - let ab' = pair_map (fun tv -> type_value ~tv ~substs) ab in - C_equation ab' + | C_equation { aval; bval } -> ( + let aux tv = type_value ~tv ~substs in + C_equation { aval = aux aval ; bval = aux bval } ) - | C_typeclass (tvs , tc) -> ( - let tvs' = List.map (fun tv -> type_value ~tv ~substs) tvs in - let tc' = typeclass ~tc ~substs in - C_typeclass (tvs' , tc') + | C_typeclass { tc_args; typeclass=tc } -> ( + let tc_args = List.map (fun tv -> type_value ~tv ~substs) tc_args in + let tc = typeclass ~tc ~substs in + C_typeclass {tc_args ; typeclass=tc} ) - | C_access_label (tv , l , v') -> ( - let tv' = type_value ~tv ~substs in - C_access_label (tv' , l , v') + | C_access_label { c_access_label_tval; accessor; c_access_label_tvar } -> ( + let c_access_label_tval = type_value ~tv:c_access_label_tval ~substs in + C_access_label {c_access_label_tval ; accessor ; c_access_label_tvar} ) and typeclass ~tc ~substs = @@ -269,9 +268,9 @@ module Substitution = struct (* Performs beta-reduction at the root of the type *) let eval_beta_root ~(tv : type_value) = match tv with - P_apply (P_forall { binder; constraints; body }, arg) -> - let constraints = List.map (fun c -> constraint_ ~c ~substs:(mk_substs ~v:binder ~expr:arg)) constraints in - (type_value ~tv:body ~substs:(mk_substs ~v:binder ~expr:arg) , constraints) + P_apply {tf = P_forall { binder; constraints; body }; targ} -> + let constraints = List.map (fun c -> constraint_ ~c ~substs:(mk_substs ~v:binder ~expr:targ)) constraints in + (type_value ~tv:body ~substs:(mk_substs ~v:binder ~expr:targ) , constraints) | _ -> (tv , []) end diff --git a/src/stages/typesystem/shorthands.ml b/src/stages/typesystem/shorthands.ml index 08b25ae5b..c01775120 100644 --- a/src/stages/typesystem/shorthands.ml +++ b/src/stages/typesystem/shorthands.ml @@ -1,7 +1,9 @@ +open Ast_typed.Types open Core +open Ast_typed.Misc -let tc type_vars allowed_list = - Core.C_typeclass (type_vars , allowed_list) +let tc type_vars allowed_list : type_constraint = + C_typeclass {tc_args = type_vars ; typeclass = allowed_list} let forall binder f = let () = ignore binder in @@ -45,32 +47,32 @@ let forall2_tc a b f = f a' b' let (=>) tc ty = (tc , ty) -let (-->) arg ret = P_constant (C_arrow , [arg; ret]) -let option t = P_constant (C_option , [t]) -let pair a b = P_constant (C_record , [a; b]) -let sum a b = P_constant (C_variant, [a; b]) -let map k v = P_constant (C_map , [k; v]) -let unit = P_constant (C_unit , []) -let list t = P_constant (C_list , [t]) -let set t = P_constant (C_set , [t]) -let bool = P_variable (Stage_common.Constant.t_bool) -let string = P_constant (C_string , []) -let nat = P_constant (C_nat , []) -let mutez = P_constant (C_mutez , []) -let timestamp = P_constant (C_timestamp , []) -let int = P_constant (C_int , []) -let address = P_constant (C_address , []) -let chain_id = P_constant (C_chain_id , []) -let bytes = P_constant (C_bytes , []) -let key = P_constant (C_key , []) -let key_hash = P_constant (C_key_hash , []) -let signature = P_constant (C_signature , []) -let operation = P_constant (C_operation , []) -let contract t = P_constant (C_contract , [t]) +let (-->) arg ret = p_constant C_arrow [arg; ret] +let option t = p_constant C_option [t] +let pair a b = p_constant C_record [a; b] +let sum a b = p_constant C_variant [a; b] +let map k v = p_constant C_map [k; v] +let unit = p_constant C_unit [] +let list t = p_constant C_list [t] +let set t = p_constant C_set [t] +let bool = P_variable Stage_common.Constant.t_bool +let string = p_constant C_string [] +let nat = p_constant C_nat [] +let mutez = p_constant C_mutez [] +let timestamp = p_constant C_timestamp [] +let int = p_constant C_int [] +let address = p_constant C_address [] +let chain_id = p_constant C_chain_id [] +let bytes = p_constant C_bytes [] +let key = p_constant C_key [] +let key_hash = p_constant C_key_hash [] +let signature = p_constant C_signature [] +let operation = p_constant C_operation [] +let contract t = p_constant C_contract [t] let ( * ) a b = pair a b (* These are used temporarily to de-curry functions that correspond to Michelson operators *) -let tuple0 = P_constant (C_record , []) -let tuple1 a = P_constant (C_record , [a]) -let tuple2 a b = P_constant (C_record , [a; b]) -let tuple3 a b c = P_constant (C_record , [a; b; c]) +let tuple0 = p_constant C_record [] +let tuple1 a = p_constant C_record [a] +let tuple2 a b = p_constant C_record [a; b] +let tuple3 a b c = p_constant C_record [a; b; c] diff --git a/src/test/adt_generator/use_a_fold.ml b/src/test/adt_generator/use_a_fold.ml index c065cabf2..484940341 100644 --- a/src/test/adt_generator/use_a_fold.ml +++ b/src/test/adt_generator/use_a_fold.ml @@ -58,8 +58,8 @@ let () = (* Test that the same fold_map_config can be ascibed with different 'a type arguments *) -let _noi : (int, [> error]) fold_map_config = no_op (* (fun _ -> ()) *) -let _nob : (bool, [> error]) fold_map_config = no_op (* (fun _ -> ()) *) +let _noi : (int, [> error]) fold_map_config__Amodule = no_op (* (fun _ -> ()) *) +let _nob : (bool, [> error]) fold_map_config__Amodule = no_op (* (fun _ -> ()) *) let () = let some_root : root = A [ { a1 = X (A [ { a1 = X (B [ 1 ; 2 ; 3 ]) ; a2 = W () } ]) ; a2 = Z (W ()) } ] in diff --git a/vendors/Red-Black_Trees/PolyMap.ml b/vendors/Red-Black_Trees/PolyMap.ml index ee485ec40..78ab1738c 100644 --- a/vendors/Red-Black_Trees/PolyMap.ml +++ b/vendors/Red-Black_Trees/PolyMap.ml @@ -11,7 +11,7 @@ type ('key, 'value) map = ('key, 'value) t let create ~cmp = {tree = RB.empty; cmp} -let empty = {tree = RB.empty; cmp=Pervasives.compare} +let empty map = {tree = RB.empty; cmp=map.cmp} let is_empty map = RB.is_empty map.tree @@ -19,6 +19,10 @@ let add key value map = let cmp (k1,_) (k2,_) = map.cmp k1 k2 in {map with tree = RB.add ~cmp RB.New (key, value) map.tree} +let remove key map = + let cmp k1 (k2,_) = map.cmp k1 k2 in + {map with tree = RB.remove ~cmp key map.tree} + exception Not_found let find key map = @@ -29,6 +33,11 @@ let find key map = let find_opt key map = try Some (find key map) with Not_found -> None +let update key updater map = + match updater (find_opt key map) with + | None -> failwith "TODO: RedBlackTrees: remove not implemented" (* TODO: remove key *) + | Some v -> add key v map + let bindings map = RB.fold_dec (fun ~elt ~acc -> elt::acc) ~init:[] map.tree diff --git a/vendors/Red-Black_Trees/PolyMap.mli b/vendors/Red-Black_Trees/PolyMap.mli index 7aafb8ae0..01e0d1468 100644 --- a/vendors/Red-Black_Trees/PolyMap.mli +++ b/vendors/Red-Black_Trees/PolyMap.mli @@ -20,7 +20,7 @@ type ('key, 'value) map = ('key, 'value) t val create : cmp:('key -> 'key -> int) -> ('key, 'value) t -val empty : ('key, 'value) t +val empty : ('key, 'value) t -> ('key, 'new_value) t (* Emptiness *) @@ -33,6 +33,11 @@ val is_empty : ('key, 'value) t -> bool val add : 'key -> 'value -> ('key, 'value) t -> ('key, 'value) t +(* The value of the call [add key value map] is a map containing all + the bindings of the map [map], except for the binding of [key]. *) + +val remove : 'key -> ('key, 'value) t -> ('key, 'value) t + (* The value of the call [find key map] is the value associated to the [key] in the map [map]. If [key] is not bound in [map], the exception [Not_found] is raised. *) @@ -47,6 +52,17 @@ val find : 'key -> ('key, 'value) t -> 'value val find_opt : 'key -> ('key, 'value) t -> 'value option +(* The value of the call [update key f map] is a map containing all + the bindings of the map [map], extended by the binding of [key] to + the value returned by [f], when [f maybe_value] returns + [Some value]. On the other hand, when [f maybe_value] returns + [None], the existing binding for [key] in [map] is removed from the + map, if there is one. The argument [maybe_value] passed to [f] is + [Some value] if the key [key] is bound to [value] in the map [map], + and [None] otherwise. *) + +val update : 'key -> ('value option -> 'value option) -> ('key, 'value) map -> ('key, 'value) map + (* The value of the call [bindings map] is the association list containing the bindings of the map [map], sorted by increasing keys (with respect to the total comparison function used to create the diff --git a/vendors/Red-Black_Trees/PolySet.ml b/vendors/Red-Black_Trees/PolySet.ml index 7e60fc3bd..7bbc3d628 100644 --- a/vendors/Red-Black_Trees/PolySet.ml +++ b/vendors/Red-Black_Trees/PolySet.ml @@ -11,7 +11,7 @@ type 'elt set = 'elt t let create ~cmp = {tree = RB.empty; cmp} -let empty = {tree = RB.empty; cmp=Pervasives.compare} +let empty set = {tree = RB.empty; cmp=set.cmp} let is_empty set = RB.is_empty set.tree diff --git a/vendors/Red-Black_Trees/PolySet.mli b/vendors/Red-Black_Trees/PolySet.mli index 42f85a529..b76ebfd97 100644 --- a/vendors/Red-Black_Trees/PolySet.mli +++ b/vendors/Red-Black_Trees/PolySet.mli @@ -19,7 +19,7 @@ type 'elt set = 'elt t val create : cmp:('elt -> 'elt -> int) -> 'elt t -val empty : 'elt t +val empty : 'elt t -> 'elt t (* Emptiness *) diff --git a/vendors/Red-Black_Trees/RedBlack.ml b/vendors/Red-Black_Trees/RedBlack.ml index 50bb9659f..4241363dc 100644 --- a/vendors/Red-Black_Trees/RedBlack.ml +++ b/vendors/Red-Black_Trees/RedBlack.ml @@ -50,6 +50,32 @@ let add ~cmp choice elt tree = in try blacken (insert tree) with Physical_equality -> tree +let remove : type a b . cmp:(a -> b -> int) -> a -> b t -> b t = fun ~cmp elt tree -> + (* TODO: this leaves the tree not properly balanced. *) + let rec bst_shift_up : b t -> b t = function + | Ext -> failwith "unknown error" + | Int (colour, left, root, right) -> + ( + ignore root; (* we delete the root *) + match left, right with + | Ext, Ext -> Ext + | Ext, Int (_rcolour, _rleft, rroot, _rright) -> + let new_right = bst_shift_up right in + Int (colour, Ext, rroot, new_right) + | Int (_lcolour, _lleft, lroot, _lright), _ -> + let new_left = bst_shift_up left in + Int (colour, new_left, lroot, right) + ) in + let rec bst_delete : a -> b t -> b t = fun elt -> function + | Ext -> failwith "remove in red-black tree: element not found" + | Int (colour, left, root, right) as current -> + let c = cmp elt root in + if c = 0 then bst_shift_up current + else if c < 0 then Int (colour, bst_delete elt left, root, right) + else Int (colour, left, root, bst_delete elt right) + in + bst_delete elt tree + exception Not_found let rec find ~cmp elt = function diff --git a/vendors/Red-Black_Trees/RedBlack.mli b/vendors/Red-Black_Trees/RedBlack.mli index 65a45230c..9642da8e6 100644 --- a/vendors/Red-Black_Trees/RedBlack.mli +++ b/vendors/Red-Black_Trees/RedBlack.mli @@ -26,6 +26,15 @@ type choice = Old | New val add: cmp:('a -> 'a -> int) -> choice -> 'a -> 'a t -> 'a t +(* The value of the call [remove ~cmp x t] is a red-black tree + containing the same elements as [t] with the exception of the + element identified by [x]. The type of [x] can be different from + that of the elements of the tree, for example if the tree is used to + implement a map, x would be a [key], whereas the elements of the tree + would be [key, value] pairs. *) + +val remove: cmp:('a -> 'b -> int) -> 'a -> 'b t -> 'b t + (* The value of the call [find ~cmp x t] is the element [y] belonging to a node of the tree [t], such that [cmp x y = true]. If none, the exception [Not_found] is raised. *) diff --git a/vendors/UnionFind/Poly2.ml b/vendors/UnionFind/Poly2.ml index dd3660b14..047bd9934 100644 --- a/vendors/UnionFind/Poly2.ml +++ b/vendors/UnionFind/Poly2.ml @@ -43,6 +43,7 @@ let map_empty (compare : 'item -> 'item -> int) : ('item, 'value) map = RedBlack let map_find : 'item 'value . 'item -> ('item, 'value) map -> 'value = RedBlackTrees.PolyMap.find let map_iter : 'item 'value . ('item -> 'value -> unit) -> ('item, 'value) map -> unit = RedBlackTrees.PolyMap.iter let map_add : 'item 'value . 'item -> 'value -> ('item, 'value) map -> ('item, 'value) map = RedBlackTrees.PolyMap.add +let map_sorted_keys : 'item 'value . ('item, 'value) map -> 'item list = fun m -> List.map fst @@ RedBlackTrees.PolyMap.bindings m (** The type [partition] implements a partition of classes of equivalent items by means of a map from items to nodes of type @@ -76,17 +77,20 @@ let is_equiv (i: 'item) (j: 'item) (p: 'item partition) : bool = try equal p.compare (repr i p) (repr j p) with Not_found -> false -let get_or_set (i: 'item) (p: 'item partition) = +let get_or_set_h (i: 'item) (p: 'item partition) = try seek i p, p with Not_found -> let n = i,0 in (n, root n p) +let get_or_set (i: 'item) (p: 'item partition) = + let (i, _h), p = get_or_set_h i p in (i, p) + let mem i p = try Some (repr i p) with Not_found -> None let repr i p = try repr i p with Not_found -> i let equiv (i: 'item) (j: 'item) (p: 'item partition) : 'item partition = - let (ri,hi as ni), p = get_or_set i p in - let (rj,hj as nj), p = get_or_set j p in + let (ri,hi as ni), p = get_or_set_h i p in + let (rj,hj as nj), p = get_or_set_h j p in if equal p.compare ri rj then p else if hi > hj @@ -104,8 +108,8 @@ let equiv (i: 'item) (j: 'item) (p: 'item partition) : 'item partition = applied (which, without the constraint above, would yield a height-balanced new tree). *) let alias (i: 'item) (j: 'item) (p: 'item partition) : 'item partition = - let (ri,hi as ni), p = get_or_set i p in - let (rj,hj as nj), p = get_or_set j p in + let (ri,hi as ni), p = get_or_set_h i p in + let (rj,hj as nj), p = get_or_set_h j p in if equal p.compare ri rj then p else if hi = hj || equal p.compare ri i @@ -113,10 +117,15 @@ let alias (i: 'item) (j: 'item) (p: 'item partition) : 'item partition = else if hi < hj then link ni rj p else link nj ri p +(** {1 iteration over the elements} *) + +let elements : 'item . 'item partition -> 'item list = + fun { to_string=_; compare=_; map } -> + map_sorted_keys map + (** {1 Printing} *) -let print (p: 'item partition) = - let buffer = Buffer.create 80 in +let print ppf (p: 'item partition) = let print i node = let hi, hj, j = match node with @@ -124,8 +133,8 @@ let print (p: 'item partition) = | Link (j,hi) -> match map_find j p.map with Root hj | Link (_,hj) -> hi,hj,j in - let link = - Printf.sprintf "%s,%d -> %s,%d\n" + let () = + Format.fprintf ppf "%s,%d -> %s,%d\n" (p.to_string i) hi (p.to_string j) hj - in Buffer.add_string buffer link - in map_iter print p.map; buffer + in () + in map_iter print p.map diff --git a/vendors/UnionFind/Poly2.mli b/vendors/UnionFind/Poly2.mli new file mode 100644 index 000000000..f6db36a85 --- /dev/null +++ b/vendors/UnionFind/Poly2.mli @@ -0,0 +1,63 @@ +(** This module offers the abstract data type of a partition of + classes of equivalent items (Union & Find). *) + +(** The items are of type 't, they have to obey a total order, + but also they must be printable to ease debugging. *) + +type 'item partition +type 'item t = 'item partition + +(** {1 Creation} *) + +(** The value [empty] is an empty partition. *) +val empty : ('a -> string) -> ('a -> 'a -> int) -> 'a partition + +(** The value of [equiv i j p] is the partition [p] extended with + the equivalence of items [i] and [j]. If both [i] and [j] are + already known to be equivalent, then [equiv i j p == p]. *) +val equiv : 'item -> 'item -> 'item t -> 'item partition + +(** The value of [alias i j p] is the partition [p] extended with + the fact that item [i] is an alias of item [j]. This is the + same as [equiv i j p], except that it is guaranteed that the + item [i] is not the representative of its equivalence class in + [alias i j p]. *) +val alias : 'item -> 'item -> 'item partition -> 'item partition + +(** {1 Projection} *) + +(** The value of the call [repr i p] is [j] if the item [i] is in + the partition [p] and its representative is [j]. If [i] is not + in [p], then the value is [i]. *) +val repr : 'item -> 'item partition -> 'item + +(** The value of the call [get_or_set i p] is [j, p] if the item [i] is + in the partition [p] and its representative is [j]. If [i] is not + in [p], then the value is [i, p'], where p' is the partition [p] + extended with the fact that item [i] is a singleton partition. *) + +val get_or_set : 'item -> 'item t -> 'item * 'item t + +(** The value of the call [mem i p] is [Some j] if the item [i] is + in the partition [p] and its representative is [j]. If [i] is + not in [p], then the value is [None]. *) +val mem : 'item -> 'item partition -> 'item option + +(** The value of the call [elements p] is a list of the elements of p, + ordered in ascending order *) +val elements : 'item partition -> 'item list + +(** The call [print p] is a value of type [Buffer.t] containing + strings denoting the partition [p], based on + [Ord.to_string]. *) +val print : Format.formatter -> 'item partition -> unit + +(** {1 Predicates} *) + +(** The value of [is_equiv i j p] is [true] if, and only if, the + items [i] and [j] belong to the same equivalence class in the + partition [p], that is, [i] and [j] have the same + representative. In particular, if either [i] or [j] do not + belong to [p], the value of [is_equiv i j p] is [false]. See + [mem] above. *) +val is_equiv : 'item -> 'item -> 'item partition -> bool