diff --git a/src/passes/4-typer/dune b/src/passes/4-typer/dune
index 0ee58cc43..ec35ab2ce 100644
--- a/src/passes/4-typer/dune
+++ b/src/passes/4-typer/dune
@@ -7,6 +7,7 @@
     ast_simplified
     ast_typed
     operators
+    union_find
   )
   (preprocess
     (pps ppx_let)
diff --git a/src/passes/4-typer/solver.ml b/src/passes/4-typer/solver.ml
new file mode 100644
index 000000000..890d067e3
--- /dev/null
+++ b/src/passes/4-typer/solver.ml
@@ -0,0 +1,716 @@
+open Trace
+
+module Core = Typesystem.Core
+
+(* Wrap: builds, for each construct of the simplified AST, the list of type
+   constraints it induces together with a fresh type variable standing for
+   the type of the whole expression. *)
+module Wrap = struct
+  module I = Ast_simplified
+  module O = Core
+
+  type constraints = O.type_constraint list
+
+  (* let add_type state t = *)
+  (*   let constraints = Wrap.variable type_name t in *)
+  (*   let%bind state' = aggregate_constraints state constraints in *)
+  (*   ok state' in *)
+  (* let return_add_type ?(state = state) expr t = *)
+  (*   let%bind state' = add_type state t in *)
+  (*   return expr state' in *)
+
+  (* Translates a simplified type expression into a solver type value.
+     Named constants outside the list below are not supported yet and raise
+     through failwith. *)
+  let rec type_expression_to_type_value : I.type_expression -> O.type_value = fun te ->
+    match te with
+    | T_tuple types ->
+      P_constant (C_tuple, List.map type_expression_to_type_value types)
+    | T_sum kvmap ->
+      P_constant (C_variant, Map.String.to_list @@ Map.String.map type_expression_to_type_value kvmap)
+    | T_record kvmap ->
+      P_constant (C_record, Map.String.to_list @@ Map.String.map type_expression_to_type_value kvmap)
+    | T_function (arg , ret) ->
+      P_constant (C_arrow, List.map type_expression_to_type_value [ arg ; ret ])
+    | T_variable type_name -> P_variable type_name
+    | T_constant (type_name , args) ->
+      let csttag = Core.(match type_name with
+          | "arrow" -> C_arrow
+          | "option" -> C_option
+          | "tuple" -> C_tuple
+          | "map" -> C_map
+          | "list" -> C_list
+          | "set" -> C_set
+          | "unit" -> C_unit
+          | "bool" -> C_bool
+          | "string" -> C_string
+          | _ -> failwith "TODO")
+      in
+      P_constant (csttag, List.map type_expression_to_type_value args)
+
+  (** TODO *)
+  let type_declaration : I.declaration -> constraints = fun td ->
+    match td with
+    | Declaration_type (name , te) ->
+      let pattern = type_expression_to_type_value te in
+      [C_equation (P_variable (name) , pattern)] (* TODO: this looks wrong. If this is a type declaration, it should not set any constraints. *)
+    | Declaration_constant (name, te, _) ->(
+        match te with
+        | Some (exp) ->
+          let pattern = type_expression_to_type_value exp in
+          [C_equation (P_variable (name) , pattern)] (* TODO: this looks wrong. If this is a type declaration, it should not set any constraints. *)
+        | None ->
+          (** TODO *)
+          []
+      )
+
+  (* TODO: this should be renamed to failwith_ *)
+  (* Yields no constraint and a fresh, unconstrained type variable. From this
+     point on, failwith inside Wrap refers to this definition. *)
+  let failwith : unit -> (constraints * O.type_variable) = fun () ->
+    let type_name = Core.fresh_type_variable () in
+    [] , type_name
+
+  (* Equates a fresh type variable with the translated type; the name is ignored. *)
+  let variable : I.name -> I.type_expression -> (constraints * O.type_variable) = fun _name expr ->
+    let pattern = type_expression_to_type_value expr in
+    let type_name = Core.fresh_type_variable () in
+    [C_equation (P_variable (type_name) , pattern)] , type_name
+
+  (* Equates a fresh type variable with the translated literal type. *)
+  let literal : I.type_expression -> (constraints * O.type_variable) = fun t ->
+    let pattern = type_expression_to_type_value t in
+    let type_name = Core.fresh_type_variable () in
+    [C_equation (P_variable (type_name) , pattern)] , type_name
+
+  (*
+  let literal_bool : unit -> (constraints * O.type_variable) = fun () ->
+    let pattern = type_expression_to_type_value I.t_bool in
+    let type_name = Core.fresh_type_variable () in
+    [C_equation (P_variable (type_name) , pattern)] , type_name
+
+  let literal_string : unit -> (constraints * O.type_variable) = fun () ->
+    let pattern = type_expression_to_type_value I.t_string in
+    let type_name = Core.fresh_type_variable () in
+    [C_equation (P_variable (type_name) , pattern)] , type_name
+  *)
+
+  (* Equates a fresh type variable with the tuple of the translated component types. *)
+  let tuple : I.type_expression list -> (constraints * O.type_variable) = fun tys ->
+    let patterns = List.map type_expression_to_type_value tys in
+    let pattern = O.(P_constant (C_tuple , patterns)) in
+    let type_name = Core.fresh_type_variable () in
+    [C_equation (P_variable (type_name) , pattern)] , type_name
+
+  (* let t_tuple = ('label:int, 'v) … -> record ('label : 'v) … *)
+  (* let t_constructor = ('label:string, 'v) -> variant ('label : 'v) *)
+  (* let t_record = ('label:string, 'v) … -> record ('label : 'v) … with independent choices for each 'label and 'v *)
+  (* let t_variable = t_of_var_in_env *)
+  (* let t_access_int = record ('label:int , 'v) … -> 'label:int -> 'v *)
+  (* let t_access_string = record ('label:string , 'v) … -> 'label:string -> 'v *)
+
+  module Prim_types = struct
+    open Typesystem.Shorthands
+
+    let t_cons = forall "v" @@ fun v -> v --> list v --> list v (* was: list *)
+    let t_setcons = forall "v" @@ fun v -> v --> set v --> set v (* was: set *)
+    let t_mapcons = forall2 "k" "v" @@ fun k v -> (k * v) --> map k v --> map k v (* was: map *)
+    let t_failwith = forall "a" @@ fun a -> a
+    (* let t_literal_t = t *)
+    let t_literal_bool = bool
+    let t_literal_string = string
+    let t_access_map = forall2 "k" "v" @@ fun k v -> map k v --> k --> v
+    let t_application = forall2 "a" "b" @@ fun a b -> (a --> b) --> a --> b
+    let t_look_up = forall2 "ind" "v" @@ fun ind v -> map ind v --> ind --> option v
+    let t_sequence = forall "b" @@ fun b -> unit --> b --> b
+    let t_loop = bool --> unit --> unit
+  end
+
+  (* TODO: I think we should take an I.expression for the base+label *)
+  (* Emits a single C_access_label constraint whose result is a fresh type variable. *)
+  let access_label ~base ~label : (constraints * O.type_variable) =
+    let base' = type_expression_to_type_value base in
+    let expr_type = Core.fresh_type_variable () in
+    [O.C_access_label (base' , label , expr_type)] , expr_type
+
+  let access_int ~base ~index = access_label ~base ~label:(L_int index)
+  let access_string ~base ~property = access_label ~base ~label:(L_string property)
+
+  (* Constrains base to be a map from a fresh key type to a fresh element
+     type, key to have that key type, and the whole expression to have the
+     element type. *)
+  let access_map : base:I.type_expression -> key:I.type_expression -> (constraints * O.type_variable) =
+    let mk_map_type key_type element_type =
+      (* The arguments of C_map are ordered [key ; value], as in [map] and
+         [look_up] below; they used to be swapped here. *)
+      O.P_constant O.(C_map , [P_variable key_type; P_variable element_type]) in
+    fun ~base ~key ->
+      let key_type = Core.fresh_type_variable () in
+      let element_type = Core.fresh_type_variable () in
+      let base' = type_expression_to_type_value base in
+      let key' = type_expression_to_type_value key in
+      let base_expected = mk_map_type key_type element_type in
+      let expr_type = Core.fresh_type_variable () in
+      O.[C_equation (base' , base_expected);
+         C_equation (key' , P_variable key_type);
+         C_equation (P_variable expr_type , P_variable element_type)] , expr_type
+
+  (* The whole expression has the sum type, and the two argument types must agree. *)
+  let constructor
+    : I.type_expression -> I.type_expression -> I.type_expression -> (constraints * O.type_variable)
+    = fun t_arg c_arg sum ->
+      let t_arg = type_expression_to_type_value t_arg in
+      let c_arg = type_expression_to_type_value c_arg in
+      let sum = type_expression_to_type_value sum in
+      let whole_expr = Core.fresh_type_variable () in
+      [
+        C_equation (P_variable (whole_expr) , sum) ;
+        C_equation (t_arg , c_arg)
+      ] , whole_expr
+
+  (* Equates a fresh type variable with the record type built from the fields. *)
+  let record : I.type_expression I.type_name_map -> (constraints * O.type_variable) = fun fields ->
+    let record_type = type_expression_to_type_value (I.t_record fields) in
+    let whole_expr = Core.fresh_type_variable () in
+    [C_equation (P_variable whole_expr , record_type)] , whole_expr
+
+  (* Shared by [list] and [set]: every element type is equated with one fresh
+     element type, and the whole expression is ctor applied to it. *)
+  let collection : O.constant_tag -> I.type_expression list -> (constraints * O.type_variable) =
+    fun ctor element_tys ->
+      let elttype = O.P_variable (Core.fresh_type_variable ()) in
+      let aux elt =
+        let elt' = type_expression_to_type_value elt
+        in O.C_equation (elttype , elt') in
+      let equations = List.map aux element_tys in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (P_variable whole_expr , O.P_constant (ctor , [elttype]))
+      ] @ equations , whole_expr
+
+  let list = collection O.C_list
+  let set = collection O.C_set
+
+  (* Every key type is equated with one fresh key type and every value type
+     with one fresh value type; the whole expression is C_map [key ; value]. *)
+  let map : (I.type_expression * I.type_expression) list -> (constraints * O.type_variable) =
+    fun kv_tys ->
+      let k_type = O.P_variable (Core.fresh_type_variable ()) in
+      let v_type = O.P_variable (Core.fresh_type_variable ()) in
+      let aux_k (k , _v) =
+        let k' = type_expression_to_type_value k in
+        O.C_equation (k_type , k') in
+      let aux_v (_k , v) =
+        let v' = type_expression_to_type_value v in
+        O.C_equation (v_type , v') in
+      let equations_k = List.map aux_k kv_tys in
+      let equations_v = List.map aux_v kv_tys in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (P_variable whole_expr , O.P_constant (C_map , [k_type ; v_type]))
+      ] @ equations_k @ equations_v , whole_expr
+
+  (* f must be an arrow from the argument type to the type of the whole expression. *)
+  let application : I.type_expression -> I.type_expression -> (constraints * O.type_variable) =
+    fun f arg ->
+      let whole_expr = Core.fresh_type_variable () in
+      let f' = type_expression_to_type_value f in
+      let arg' = type_expression_to_type_value arg in
+      O.[
+        C_equation (f' , P_constant (C_arrow , [arg' ; P_variable whole_expr]))
+      ] , whole_expr
+
+  (* ds must be a map from the index type to a fresh value type v; the whole
+     expression is option v. *)
+  let look_up : I.type_expression -> I.type_expression -> (constraints * O.type_variable) =
+    fun ds ind ->
+      let ds' = type_expression_to_type_value ds in
+      let ind' = type_expression_to_type_value ind in
+      let whole_expr = Core.fresh_type_variable () in
+      let v = Core.fresh_type_variable () in
+      O.[
+        C_equation (ds' , P_constant (C_map, [ind' ; P_variable v])) ;
+        C_equation (P_variable whole_expr , P_constant (C_option , [P_variable v]))
+      ] , whole_expr
+
+  (* a must be unit; the whole expression has the type of b. *)
+  let sequence : I.type_expression -> I.type_expression -> (constraints * O.type_variable) =
+    fun a b ->
+      let a' = type_expression_to_type_value a in
+      let b' = type_expression_to_type_value b in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (a' , P_constant (C_unit , [])) ;
+        C_equation (b' , P_variable whole_expr)
+      ] , whole_expr
+
+  (* The condition must be bool, the body unit, and the whole expression unit. *)
+  let loop : I.type_expression -> I.type_expression -> (constraints * O.type_variable) =
+    fun expr body ->
+      let expr' = type_expression_to_type_value expr in
+      let body' = type_expression_to_type_value body in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (expr' , P_constant (C_bool , [])) ;
+        C_equation (body' , P_constant (C_unit , [])) ;
+        C_equation (P_variable whole_expr , P_constant (C_unit , []))
+      ] , whole_expr
+
+  (* The whole expression has the type of the result; when an annotation is
+     given, the right-hand side must have the annotated type. *)
+  let let_in : I.type_expression -> I.type_expression option -> I.type_expression -> (constraints * O.type_variable) =
+    fun rhs rhs_tv_opt result ->
+      let rhs' = type_expression_to_type_value rhs in
+      let result' = type_expression_to_type_value result in
+      let rhs_tv_opt' = match rhs_tv_opt with
+          None -> []
+        | Some annot -> O.[C_equation (rhs' , type_expression_to_type_value annot)] in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (result' , P_variable whole_expr)
+      ] @ rhs_tv_opt', whole_expr
+
+  (* Both sides of the assignment must agree; the whole expression is unit. *)
+  let assign : I.type_expression -> I.type_expression -> (constraints * O.type_variable) =
+    fun v e ->
+      let v' = type_expression_to_type_value v in
+      let e' = type_expression_to_type_value e in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (v' , e') ;
+        C_equation (P_variable whole_expr , P_constant (C_unit , []))
+      ] , whole_expr
+
+  (* The expression must have the annotated type, which is also the type of
+     the whole expression. *)
+  let annotation : I.type_expression -> I.type_expression -> (constraints * O.type_variable) =
+    fun e annot ->
+      let e' = type_expression_to_type_value e in
+      let annot' = type_expression_to_type_value annot in
+      let whole_expr = Core.fresh_type_variable () in
+      O.[
+        C_equation (e' , annot') ;
+        C_equation (e' , P_variable whole_expr)
+      ] , whole_expr
+
+  (* Every given type must equal the type of the whole expression. *)
+  let matching : I.type_expression list -> (constraints * O.type_variable) =
+    fun es ->
+      let whole_expr = Core.fresh_type_variable () in
+      let type_values = (List.map type_expression_to_type_value es) in
+      let cs = List.map (fun e -> O.C_equation (P_variable whole_expr , e)) type_values
+      in cs, whole_expr
+
+  let fresh_binder () =
+    Core.fresh_type_variable ()
+
+  (* The whole expression is an arrow between two fresh unification
+     variables; the binder type is equated with the argument variable, and
+     the optional annotations with the matching variable. *)
+  let lambda
+    : I.type_expression ->
+      I.type_expression option ->
+      I.type_expression option ->
+      (constraints * O.type_variable) =
+    fun fresh arg body ->
+      let whole_expr = Core.fresh_type_variable () in
+      let unification_arg = Core.fresh_type_variable () in
+      let unification_body = Core.fresh_type_variable () in
+      let arg' = match arg with
+          None -> []
+        | Some arg -> O.[C_equation (P_variable unification_arg , type_expression_to_type_value arg)] in
+      let body' = match body with
+          None -> []
+        | Some body -> O.[C_equation (P_variable unification_body , type_expression_to_type_value body)]
+      in O.[
+          C_equation (type_expression_to_type_value fresh , P_variable unification_arg) ;
+          C_equation (P_variable whole_expr ,
+                      P_constant (C_arrow , [P_variable unification_arg ;
+                                             P_variable unification_body]))
+        ] @ arg' @ body' , whole_expr
+
+end
+
+(* begin unionfind *)
+
+module TV =
+struct
+  type t = Core.type_variable
+  let compare = String.compare
+  let to_string = (fun s -> s)
+end
+
+module UF = Union_find.Partition0.Make(TV)
+
+type unionfind = UF.t
+
+let empty = UF.empty (* DEMO *)
+let representative_toto = UF.repr "toto" empty (* DEMO *)
+let merge x y = UF.equiv x y (* DEMO *)
+
+(* end unionfind *)
+
+(* representant for an equivalence class of type variables *)
+module TypeVariable = String
+module TypeVariableMap = Map.Make(TypeVariable)
+
+
+(*
+
+Components:
+* assignments (passive data structure).
+  Now: just a map from unification vars to types (pb: what about partial types?)
+  maybe just local assignments (allow only vars as children of pair(α,β))
+* constraint propagation: (bunch of constraints) → (new constraints * assignments)
+  * sub-component: constraint selector (worklist / dynamic queries)
+    * sub-sub component: constraint normalizer: remove dupes and give structure
+      right now: union-find of unification vars
+      later: better database-like organisation of knowledge
+    * sub-sub component: lazy selector (don't re-try all selectors every time)
+      For now: just re-try every time
+  * sub-component: propagation rule
+    For now: break pair(a, b) = pair(c, d) into a = c, b = d
+* generalizer
+  For now: ? 
+
+Workflow:
+  Start with empty assignments and structured database
+  Receive a new constraint
+  For each normalizer:
+    Use the pre-selector to see if it can be applied
+    Apply the normalizer, get some new items to insert in the structured database
+  For each propagator:
+    Use the selector to query the structured database and see if it can be applied
+    Apply the propagator, get some new constraints and assignments
+  Add the new assignments to the data structure.
+
+  At some point (when?)
+  For each generalizer:
+    Use the generalizer's selector to see if it can be applied
+    Apply the generalizer to produce a new type, possibly with some ∀s injected
+
+*)
+
+open Core
+
+type structured_dbs = {
+  all_constraints : type_constraint_simpl list ;
+  aliases : unionfind ;
+  (* assignments (passive data structure).
+     Now: just a map from unification vars to types (pb: what about partial types?)
+     maybe just local assignments (allow only vars as children of pair(α,β)) *)
+  assignments : c_constructor_simpl TypeVariableMap.t ;
+  grouped_by_variable : constraints TypeVariableMap.t ; (* map from (unionfind) variables to constraints containing them *)
+  cycle_detection_toposort : unit ; (* example of structured db that we'll add later *)
+}
+
+and constraints = {
+  constructor : c_constructor_simpl list ;
+  tc : c_typeclass_simpl list ;
+}
+
+and c_constructor_simpl = {
+  tv : type_variable;
+  c_tag : constant_tag;
+  tv_list : type_variable list;
+}
+(* copy-pasted from core.ml *)
+and c_const = (type_variable * type_value)
+and c_equation = (type_value * type_value)
+and c_typeclass_simpl = {
+  tc : typeclass ;
+  args : type_variable list ;
+}
+and type_constraint_simpl =
+    SC_Constructor of c_constructor_simpl (* α = ctor(β, …) *)
+  | SC_Alias of (type_variable * type_variable) (* α = β *)
+  | SC_Typeclass of c_typeclass_simpl (* TC(α, …) *)
+
+module UnionFindWrapper = struct
+  (* TODO: API for the structured db, to access it modulo unification variable aliases. *)
+
+  (* Returns the constraints stored under the representative of [variable],
+     or empty lists when nothing is stored. The union-find partition
+     returned by get_or_set is only used locally: callers do not get it back. *)
+  let get_constraints_related_to : type_variable -> structured_dbs -> constraints =
+    fun variable dbs ->
+    let variable , aliases = UF.get_or_set variable dbs.aliases in
+    let dbs = { dbs with aliases } in
+    match TypeVariableMap.find_opt variable dbs.grouped_by_variable with
+      Some l -> l
+    | None -> {
+        constructor = [] ;
+        tc = [] ;
+      }
+
+  (* Prepends the constraints [c] to those stored under the representative
+     of [variable]. *)
+  let add_constraints_related_to : type_variable -> constraints -> structured_dbs -> structured_dbs =
+    fun variable c dbs ->
+    (* let (variable_repr , _height) , aliases = UF.get_or_set variable dbs.aliases in
+       let dbs = { dbs with aliases } in *)
+    let variable_repr , aliases = UF.get_or_set variable dbs.aliases in
+    let dbs = { dbs with aliases } in
+    let grouped_by_variable = TypeVariableMap.update variable_repr (function
+          None -> Some c
+        | Some x -> Some {
+            constructor = c.constructor @ x.constructor ;
+            tc = c.tc @ x.tc ;
+          })
+        dbs.grouped_by_variable
+    in
+    let dbs = { dbs with grouped_by_variable } in
+    dbs
+
+  (* Merges the classes of [variable_a] and [variable_b]: the alias is
+     recorded in the union-find and the constraints stored under both former
+     representatives are regrouped under the representative of the merged
+     class. When both variables already share a representative the database
+     is returned as is; the previous add-then-remove sequence deleted the
+     entry in that case. *)
+  let merge_variables : type_variable -> type_variable -> structured_dbs -> structured_dbs =
+    fun variable_a variable_b dbs ->
+    let variable_repr_a , aliases = UF.get_or_set variable_a dbs.aliases in
+    let dbs = { dbs with aliases } in
+    let variable_repr_b , aliases = UF.get_or_set variable_b dbs.aliases in
+    let dbs = { dbs with aliases } in
+    if TV.compare variable_repr_a variable_repr_b = 0 then
+      dbs
+    else begin
+      let default d = function None -> d | Some y -> y in
+      let get_constraints ab =
+        TypeVariableMap.find_opt ab dbs.grouped_by_variable
+        |> default { constructor = [] ; tc = [] } in
+      let constraints_a = get_constraints variable_repr_a in
+      let constraints_b = get_constraints variable_repr_b in
+      let all_constraints = {
+        (* TODO: should be a Set.union, not @ *)
+        constructor = constraints_a.constructor @ constraints_b.constructor ;
+        tc = constraints_a.tc @ constraints_b.tc ;
+      } in
+      (* NOTE(review): assumes UF.equiv takes item, item, partition and
+         returns the updated partition - confirm against Union_find.Partition0. *)
+      let aliases = UF.equiv variable_repr_a variable_repr_b dbs.aliases in
+      (* Ask the union-find which of the two is now the representative
+         rather than guessing. *)
+      let variable_repr , aliases = UF.get_or_set variable_repr_a aliases in
+      let grouped_by_variable =
+        dbs.grouped_by_variable
+        |> TypeVariableMap.remove variable_repr_a
+        |> TypeVariableMap.remove variable_repr_b
+        |> TypeVariableMap.add variable_repr all_constraints in
+      { dbs with aliases ; grouped_by_variable }
+    end
+end
+
+(* sub-sub component: constraint normalizer: remove dupes and give structure
+ * right now: union-find of unification vars
+ * later: better database-like organisation of knowledge *)
+
+(* Each normalizer returns the updated databases together with the list of
+   constraints handed over to the next normalizer. *)
+type ('a , 'b) normalizer = structured_dbs -> 'a -> (structured_dbs * 'b list)
+
+(* Records every simplified constraint in all_constraints. *)
+let normalizer_all_constraints : (type_constraint_simpl , type_constraint_simpl) normalizer =
+  fun dbs new_constraint ->
+  ({ dbs with all_constraints = new_constraint :: dbs.all_constraints } , [new_constraint])
+
+(* Indexes constructor and typeclass constraints under each variable they
+   mention; alias constraints merge the two variables instead. *)
+let normalizer_grouped_by_variable : (type_constraint_simpl , type_constraint_simpl) normalizer =
+  fun dbs new_constraint ->
+  let store_constraint tvars constraints =
+    let aux dbs (tvar : type_variable) =
+      UnionFindWrapper.add_constraints_related_to tvar constraints dbs
+    in List.fold_left aux dbs tvars
+  in
+  let merge_constraints a b =
+    UnionFindWrapper.merge_variables a b dbs in
+  let dbs = match new_constraint with
+      SC_Constructor ({tv ; c_tag = _ ; tv_list} as c) -> store_constraint (tv :: tv_list) {constructor = [c] ; tc = []}
+    | SC_Typeclass ({tc = _ ; args} as c) -> store_constraint args {constructor = [] ; tc = [c]}
+    | SC_Alias (a , b) -> merge_constraints a b
+  in (dbs , [new_constraint])
+
+(* Stores the first assignment ('a = ctor('b, …)) seen *)
+let normalizer_assignments : (type_constraint_simpl , type_constraint_simpl) normalizer =
+  fun dbs new_constraint ->
+  match new_constraint with
+  | SC_Constructor ({tv ; c_tag = _ ; tv_list = _} as c) ->
+    let assignments = TypeVariableMap.update tv (function None -> Some c | e -> e) dbs.assignments in
+    let dbs = {dbs with assignments} in
+    (dbs , [new_constraint])
+  | _ ->
+    (dbs , [new_constraint])
+
+(* Breaks a general constraint down into simplified ones, introducing fresh
+   variables for the arguments of constructors and typeclasses. Cases
+   involving P_forall and C_access_label are not implemented yet and raise
+   through failwith. *)
+let rec normalizer_simpl : (type_constraint , type_constraint_simpl) normalizer =
+  fun dbs new_constraint ->
+  match new_constraint with
+  | C_equation (P_forall _, P_forall _) -> failwith "TODO"
+  | C_equation ((P_forall _ as a), (P_variable _ as b)) -> normalizer_simpl dbs (C_equation (b , a))
+  | C_equation (P_forall _, P_constant _) -> failwith "TODO"
+  | C_equation (P_variable _, P_forall _) -> failwith "TODO"
+  | C_equation (P_variable a, P_variable b) -> (dbs , [SC_Alias (a, b)])
+  | C_equation (P_variable a, P_constant (c_tag, args)) ->
+    let fresh_vars = List.map (fun _ -> Core.fresh_type_variable ()) args in
+    let fresh_eqns = List.map (fun (v,t) -> C_equation (P_variable v, t)) (List.combine fresh_vars args) in
+    let (dbs , recur) = List.fold_map_acc normalizer_simpl dbs fresh_eqns in
+    (dbs , [SC_Constructor {tv=a;c_tag;tv_list=fresh_vars}] @ List.flatten recur)
+  | C_equation (P_constant _, P_forall _) -> failwith "TODO"
+  | C_equation ((P_constant _ as a), (P_variable _ as b)) -> normalizer_simpl dbs (C_equation (b , a))
+  | C_equation ((P_constant _ as a), (P_constant _ as b)) ->
+    (* break down c(args) = c'(args') into 'a = c(args) and 'a = c'(args') *)
+    let fresh = Core.fresh_type_variable () in
+    let (dbs , cs1) = normalizer_simpl dbs (C_equation (P_variable fresh, a)) in
+    let (dbs , cs2) = normalizer_simpl dbs (C_equation (P_variable fresh, b)) in
+    (dbs , cs1 @ cs2) (* TODO: O(n) concatenation! *)
+  | C_typeclass (args, tc) ->
+    (* break down TC(args) into TC('a, …) and ('a = arg) … *)
+    let fresh_vars = List.map (fun _ -> Core.fresh_type_variable ()) args in
+    let fresh_eqns = List.map (fun (v,t) -> C_equation (P_variable v, t)) (List.combine fresh_vars args) in
+    let (dbs , recur) = List.fold_map_acc normalizer_simpl dbs fresh_eqns in
+    (dbs, [SC_Typeclass { tc ; args = fresh_vars }] @ List.flatten recur)
+  | C_access_label (tv, label, result) -> let _todo = ignore (tv, label, result) in failwith "TODO"
+
+(* A state threaded through a list of elements; [lift] turns a normalizer
+   into a step over this structure, flattening the lists it returns. *)
+type ('state, 'elt) state_list_monad = { state: 'state ; list : 'elt list }
+let lift_state_list_monad ~state ~list = { state ; list }
+let lift f =
+  fun { state ; list } ->
+  let (new_state , new_lists) = List.fold_map_acc f state list in
+  { state = new_state ; list = List.flatten new_lists }
+
+(* TODO: move this to the List module *)
+let named_fold_left f ~acc ~lst = List.fold_left (fun acc lst -> f ~acc ~lst) acc lst
+
+(* TODO: place the list of normalizers in a map *)
+(* (\* cons for heterogeneous lists *\)
+ * type 'b f = { f : 'a . ('a -> 'b) -> 'a -> 'b }
+ * type ('hd , 'tl) hcons = { hd : 'hd ; tl : 'tl ; map : 'b . 
'b f -> ('b , 'tl) hcons }
+ * let (+::) hd tl = { hd ; tl ; map = fun x -> }
+ *
+ * let list_of_normalizers =
+ *   normalizer_simpl +::
+ *   normalizer_all_constraints +::
+ *   normalizer_assignments +::
+ *   normalizer_grouped_by_variable +::
+ *   () *)
+
+(* Local stand-in for the identity function: in stdlib as of 4.08, we're in 4.07 for now *)
+module Fun = struct
+  let id x = x
+end
+
+(* Runs one incoming constraint through every normalizer in turn, starting
+   with the simplifier, and returns the final databases together with the
+   simplified constraints that came out of the last step. *)
+let normalizers : type_constraint -> structured_dbs -> (structured_dbs , 'modified_constraint) state_list_monad =
+  fun new_constraint dbs ->
+  lift_state_list_monad ~state:dbs ~list:[new_constraint]
+  |> lift normalizer_simpl
+  |> lift normalizer_all_constraints
+  |> lift normalizer_assignments
+  |> lift normalizer_grouped_by_variable
+  |> Fun.id
+
+(* sub-sub component: lazy selector (don't re-try all selectors every time)
+ * For now: just re-try every time *)
+
+type todo = unit
+let todo : todo = ()
+type 'old_constraint_type selector_input = 'old_constraint_type (* some info about the constraint just added, so that we know what to look for *)
+type 'selector_output selector_outputs =
+    WasSelected of 'selector_output list
+  | WasNotSelected
+type new_constraints = type_constraint list
+type new_assignments = c_constructor_simpl list
+
+type ('old_constraint_type, 'selector_output) selector = 'old_constraint_type selector_input -> structured_dbs -> 'selector_output selector_outputs
+
+(* selector / propagation rule for breaking down composite types
+ * For now: do something with ('a = 'b) constraints.
+
+   Or maybe this one should be a normalizer. 
*)
+
+(* selector / propagation rule for breaking down composite types
+ * For now: break pair(a, b) = pair(c, d) into a = c, b = d *)
+
+type output_break_ctor = < a_k_var : c_constructor_simpl ; a_k'_var' : c_constructor_simpl >
+
+(* Looks for two rules with the shape a = k(var …) and a = k'(var' …): a
+   constructor constraint is paired with every constructor constraint stored
+   for its variable. Other kinds of constraints are not selected. *)
+let selector_break_ctor : (type_constraint_simpl, output_break_ctor) selector =
+  fun todo dbs ->
+  match todo with
+  | SC_Alias _ -> WasNotSelected (* TODO: ??? *)
+  | SC_Typeclass _ -> WasNotSelected
+  | SC_Constructor ctor ->
+    let pair_with other = object method a_k_var = ctor method a_k'_var' = other end in
+    let related = UnionFindWrapper.get_constraints_related_to ctor.tv dbs in
+    WasSelected (List.map pair_with related.constructor)
+
+type 'selector_output propagator = 'selector_output -> structured_dbs -> new_constraints * new_assignments
+
+(* Given a = k(var …) and a' = k'(var' …), produces a = a' plus one equation
+   per pair of arguments. Raises through failwith when the constructors or
+   their numbers of arguments differ. No assignment is produced. *)
+let propagator_break_ctor : output_break_ctor propagator =
+  fun selected dbs ->
+  let () = ignore dbs in (* the databases are not needed by this propagator *)
+  let lhs = selected#a_k_var in
+  let rhs = selected#a_k'_var' in
+  (* lhs.tv = rhs.tv *)
+  let alias_eq = C_equation (P_variable lhs.tv, P_variable rhs.tv) in
+  let same_ctor = lhs.c_tag = rhs.c_tag in
+  let same_arity () = List.length lhs.tv_list = List.length rhs.tv_list in
+  if not same_ctor then
+    failwith "type error: incompatible types, not same ctor (TODO error message)"
+  else if not (same_arity ()) then
+    failwith "type error: incompatible types, not same length (TODO error message)"
+  else
+    (* lhs.tv_list = rhs.tv_list, argument by argument *)
+    let equate aa bb = C_equation (P_variable aa, P_variable bb) in
+    let argument_eqs = List.map2 equate lhs.tv_list rhs.tv_list in
+    (alias_eq :: argument_eqs , [])
+
+(* Runs the selector and, when it selects something, applies the
+   propagation rule to every selected output. The new constraints are meant
+   to be pushed to some kind of work queue and the new assignments stored. *)
+let select_and_propagate : ('old_input, 'selector_output) selector -> 'selector_output propagator -> 'a -> structured_dbs -> new_constraints * new_assignments =
+  fun selector propagator todo dbs ->
+  match selector todo dbs with
+  | WasNotSelected -> ([] , [])
+  | WasSelected selected_outputs ->
+    let per_output = List.map (fun s -> propagator s dbs) selected_outputs in
+    let (constraint_lists , assignment_lists) = List.split per_output in
+    (List.flatten constraint_lists , List.flatten assignment_lists)
+
+let select_and_propagate_break_ctor = select_and_propagate selector_break_ctor propagator_break_ctor
+
+(* Applies the known selector/propagator pairs to one simplified constraint.
+   New assignments are stored unless the variable already has one. *)
+let select_and_propagate_all' : type_constraint_simpl selector_input -> structured_dbs -> 'todo_result =
+  fun new_constraint dbs ->
+  let (new_constraints, new_assignments) = select_and_propagate_break_ctor new_constraint dbs in
+  let keep_first ele = function None -> Some ele | existing -> existing in
+  let record acc (ele : c_constructor_simpl) =
+    TypeVariableMap.update ele.tv (keep_first ele) acc in
+  let assignments = List.fold_left record dbs.assignments new_assignments in
+  (* let blah2 = select_ … in … *)
+  (* We should try each selector in turn. If multiple selectors work, what should we do? *)
+  (new_constraints , { dbs with assignments })
+
+(* Worklist loop: normalizes the first pending constraint, propagates each
+   resulting simplified constraint, puts the constraints this produces in
+   front of the remaining ones and starts over until nothing is left. *)
+let rec select_and_propagate_all : type_constraint selector_input list -> structured_dbs -> 'todo_result =
+  fun worklist dbs ->
+  match worklist with
+  | [] -> dbs
+  | next :: rest ->
+    let normalized = normalizers next dbs in
+    let propagate (pending , dbs) simpl =
+      let (produced , dbs) = select_and_propagate_all' simpl dbs in
+      (produced @ pending , dbs) in
+    let (pending , dbs) =
+      List.fold_left propagate ([] , normalized.state) normalized.list in
+    select_and_propagate_all (pending @ rest) dbs
+
+(* sub-component: constraint selector (worklist / dynamic queries) *)
+
+(* constraint propagation: (bunch of constraints) → (new constraints * assignments) *)
+
+
+
+
+
+(* Below is a draft *)
+
+type state = {
+  (* when α-renaming x to y, we put them in the same union-find class *)
+  unification_vars : unionfind ;
+
+  (* assigns a value to the representant in the unionfind *)
+  assignments : type_value TypeVariableMap.t ;
+
+  (* constraints related to a type variable *)
+  constraints : constraints TypeVariableMap.t ;
+}
+
+let initial_state : state = {
+  unification_vars = UF.empty ;
+  assignments = TypeVariableMap.empty ;
+  constraints = TypeVariableMap.empty ;
+}
+
+(* let replace_var_in_state = fun (v : type_variable) (state : state) -> *)
+(*   let aux_tv : type_value -> _ = function *)
+(*     | P_forall (w , cs , tval) -> failwith "TODO" *)
+(*     | P_variable (w) -> *)
+(*       if w = v then *)
+(*       (*…*) *)
+(*       else *)
+(*       (*…*) *)
+(*     | P_constant (c , args) -> failwith "TODO" *)
+(*     | P_access_label (tv , label) -> failwith "TODO" in *)
+(*   let aux_tc tc = *)
+(*     List.map (fun l -> List.map aux_tv l) tc in *)
+(*   let aux : type_constraint -> _ = function *)
+(*     | C_equation (l , r) -> C_equation (aux_tv l , aux_tv r) *)
+(*     | C_typeclass (l , rs) -> C_typeclass (List.map aux_tv l , aux_tc rs) *)
+(*   in List.map aux state *)
+
+(* let 
check_equal a b = failwith "TODO"
+ * let check_same_length l1 l2 = failwith "TODO"
+ *
+ * let rec unify : type_value * type_value -> type_constraint list result = function
+ *   | (P_variable v , P_constant (y , argsy)) ->
+ *      failwith "TODO: replace v with the constant everywhere."
+ *   | (P_constant (x , argsx) , P_variable w) ->
+ *      failwith "TODO: "
+ *   | (P_variable v , P_variable w) ->
+ *      failwith "TODO: replace v with w everywhere"
+ *   | (P_constant (x , argsx) , P_constant (y , argsy)) ->
+ *      let%bind () = check_equal x y in
+ *      let%bind () = check_same_length argsx argsy in
+ *      let%bind _ = bind_map_list unify (List.combine argsx argsy) in
+ *      ok []
+ *   | _ -> failwith "TODO" *)
+
+(* (\* unify a and b, possibly produce new constraints *\) *)
+(* let () = ignore (a,b) in *)
+(* ok [] *)
+
+(* This is the solver.
+   Placeholder for now: both arguments are ignored and the call always
+   raises through failwith.
+   TODO: Iterate over constraints
+   TODO: try to unify things:
+         if we have a = X and b = Y, try to unify X and Y *)
+let aggregate_constraints : state -> type_constraint list -> state result = fun state newc ->
+  let () = ignore (state, newc) in
+  failwith "TODO"
+(*let { constraints ; eqv } = state in
+  ok { constraints = constraints @ newc ; eqv }*)
diff --git a/src/passes/4-typer/typer.ml b/src/passes/4-typer/typer.ml
index 391239506..efa65ae1b 100644
--- a/src/passes/4-typer/typer.ml
+++ b/src/passes/4-typer/typer.ml
@@ -8,6 +8,8 @@ module SMap = O.SMap
 
 module Environment = O.Environment
 
+module Solver = Solver
+
 type environment = Environment.t
 
 module Errors = struct
@@ -216,6 +218,7 @@ module Errors = struct
     ] in
     error ~data title message
 end
+
 open Errors
 
 let rec type_program (p:I.program) : O.program result =
@@ -238,6 +241,9 @@ and type_declaration env : I.declaration -> (environment * O.declaration option)
       let env' = Environment.add_type type_name tv env in
       ok (env', None)
   | Declaration_constant (name , tv_opt , expression) -> (
+      (*
+        Determine the type of the expression and add it to the environment
+      *)
       let%bind tv'_opt = 
bind_map_option (evaluate_type env) tv_opt in
       let%bind ae' =
         trace (constant_declaration_error name expression tv'_opt) @@
@@ -340,6 +346,10 @@ and type_match : type i o . (environment -> i -> o result) -> environment -> O.t
           bind_map_list aux lst in
       ok (O.Match_variant (lst' , variant))
 
+(*
+  Recursively search the type_expression and return a result containing the
+  type_value at the leaves
+*)
 and evaluate_type (e:environment) (t:I.type_expression) : O.type_value result =
   let return tv' = ok (make_t tv' (Some t)) in
   match t with
@@ -782,6 +792,9 @@ let untype_literal (l:O.literal) : I.literal result =
   | Literal_address s -> ok (Literal_address s)
   | Literal_operation s -> ok (Literal_operation s)
 
+(*
+  Transform an Ast_typed expression into an Ast_simplified expression
+*)
 let rec untype_expression (e:O.annotated_expression) : (I.expression) result =
   let open I in
   let return e = ok e in
@@ -849,6 +862,9 @@ let rec untype_expression (e:O.annotated_expression) : (I.expression) result =
     let%bind result = untype_expression result in
     return (e_let_in (binder , (Some tv)) rhs result)
 
+(*
+  Transform an Ast_typed matching into an Ast_simplified matching
+*)
 and untype_matching : type o i . (o -> i result) -> o O.matching -> (i I.matching) result = fun f m ->
   let open I in
   match m with
diff --git a/src/passes/4-typer/typer.ml.old b/src/passes/4-typer/typer.ml.old
new file mode 100644
index 000000000..dfd99cbbe
--- /dev/null
+++ b/src/passes/4-typer/typer.ml.old
@@ -0,0 +1,879 @@
+open Trace
+
+module I = Ast_simplified
+module O = Ast_typed
+open O.Combinators
+
+module SMap = O.SMap
+
+module Environment = O.Environment
+
+type environment = Environment.t
+
+(* Error builders of the typer. Each one passes a title, an empty message
+   and a list of lazily formatted data entries to [error]. *)
+module Errors = struct
+  (* A type variable name that is not bound; reports the name and the environment. *)
+  let unbound_type_variable (e:environment) (n:string) () =
+    let title = (thunk "unbound type variable") in
+    let message () = "" in
+    let data = [
+      ("variable" , fun () -> Format.asprintf "%s" n) ;
+      (* TODO: types don't have srclocs for now. *)
+      (* ("location" , fun () -> Format.asprintf "%a" Location.pp (n.location)) ; *)
+      ("in" , fun () -> Format.asprintf "%a" Environment.PP.full_environment e)
+    ] in
+    error ~data title message ()
+
+  (* A variable name that is not bound; reports the name, the environment and the location. *)
+  let unbound_variable (e:environment) (n:string) (loc:Location.t) () =
+    let title = (thunk "unbound variable") in
+    let message () = "" in
+    let data = [
+      ("variable" , fun () -> Format.asprintf "%s" n) ;
+      ("environment" , fun () -> Format.asprintf "%a" Environment.PP.full_environment e) ;
+      ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+    ] in
+    error ~data title message ()
+
+  (* A match with no cases; reports the matching and the location. *)
+  let match_empty_variant : type a . a I.matching -> Location.t -> unit -> _ =
+    fun matching loc () ->
+      let title = (thunk "match with no cases") in
+      let message () = "" in
+      let data = [
+        ("variant" , fun () -> Format.asprintf "%a" I.PP.matching_type matching) ;
+        ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+      ] in
+      error ~data title message ()
+
+  (* A match that lacks a case; reports the matching and the location. *)
+  let match_missing_case : type a . a I.matching -> Location.t -> unit -> _ =
+    fun matching loc () ->
+      let title = (thunk "missing case in match") in
+      let message () = "" in
+      let data = [
+        ("variant" , fun () -> Format.asprintf "%a" I.PP.matching_type matching) ;
+        ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+      ] in
+      error ~data title message ()
+
+  (* A match with a redundant case; reports the matching and the location.
+     The title used to be a copy of the one in match_missing_case. *)
+  let match_redundant_case : type a . a I.matching -> Location.t -> unit -> _ =
+    fun matching loc () ->
+      let title = (thunk "redundant case in match") in
+      let message () = "" in
+      let data = [
+        ("variant" , fun () -> Format.asprintf "%a" I.PP.matching_type matching) ;
+        ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+      ] in
+      error ~data title message ()
+
+  (* A constructor name that is not bound; reports the name, the environment and the location. *)
+  let unbound_constructor (e:environment) (n:string) (loc:Location.t) () =
+    let title = (thunk "unbound constructor") in
+    let message () = "" in
+    let data = [
+      ("constructor" , fun () -> Format.asprintf "%s" n) ;
+      ("environment" , fun () -> Format.asprintf "%a" Environment.PP.full_environment e) ;
+      ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+    ] in
+    error ~data title message ()
+
+  (* A constant name that is not recognized; reports the name and the location. *)
+  let unrecognized_constant (n:string) (loc:Location.t) () =
+    let title = (thunk "unrecognized constant") in
+    let message () = "" in
+    let data = [
+      ("constant" , fun () -> Format.asprintf "%s" n) ;
+      ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+    ] in
+    error ~data title message ()
+
+  (* Function [n] used with [actual] arguments where [expected] were required. *)
+  let wrong_arity (n:string) (expected:int) (actual:int) (loc : Location.t) () =
+    let title () = "wrong arity" in
+    let message () = "" in
+    let data = [
+      ("function" , fun () -> Format.asprintf "%s" n) ;
+      ("expected" , fun () -> Format.asprintf "%d" expected) ;
+      ("actual" , fun () -> Format.asprintf "%d" actual) ;
+      ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+    ] in
+    error ~data title message ()
+
+  (* A tuple matched against a different number of components; only the
+     lengths of the two lists are reported. *)
+  let match_tuple_wrong_arity (expected:'a list) (actual:'b list) (loc:Location.t) () =
+    let title () = "matching tuple of different size" in
+    let message () = "" in
+    let data = [
+      ("expected" , fun () -> Format.asprintf "%d" (List.length expected)) ;
+      ("actual" , fun () -> Format.asprintf "%d" (List.length actual)) ;
+      ("location" , fun () -> Format.asprintf "%a" Location.pp loc)
+    ] in
+    error ~data title message ()
+
+  (* TODO: this should be a trace_info? 
*) + let program_error (p:I.program) () = + let message () = "" in + let title = (thunk "typing program") in + let data = [ + ("program" , fun () -> Format.asprintf "%a" I.PP.program p) + ] in + error ~data title message () + + let constant_declaration_error (name:string) (ae:I.expr) (expected: O.type_expression option) () = + let title = (thunk "typing constant declaration") in + let message () = "" in + let data = [ + ("constant" , fun () -> Format.asprintf "%s" name) ; + ("expression" , fun () -> Format.asprintf "%a" I.PP.expression ae) ; + ("expected" , fun () -> + match expected with + None -> "(no annotation for the expected type)" + | Some expected -> Format.asprintf "%a" O.PP.type_expression expected) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp ae.location) + ] in + error ~data title message () + + let match_error : type a . ?msg:string -> expected: a I.matching -> actual: O.type_expression -> Location.t -> unit -> _ = + fun ?(msg = "") ~expected ~actual loc () -> + let title = (thunk "typing match") in + let message () = msg in + let data = [ + ("expected" , fun () -> Format.asprintf "%a" I.PP.matching_type expected); + ("actual" , fun () -> Format.asprintf "%a" O.PP.type_expression actual) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp loc) + ] in + error ~data title message () + + let needs_annotation (e : I.expression) (case : string) () = + let title = (thunk "this expression must be annotated with its type") in + let message () = Format.asprintf "%s needs an annotation" case in + let data = [ + ("expression" , fun () -> Format.asprintf "%a" I.PP.expression e) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp e.location) + ] in + error ~data title message () + + let type_error_approximate ?(msg="") ~(expected: string) ~(actual: O.type_expression) ~(expression : I.expression) (loc:Location.t) () = + let title = (thunk "type error") in + let message () = msg in + let data = [ + ("expected" , fun () -> 
Format.asprintf "%s" expected); + ("actual" , fun () -> Format.asprintf "%a" O.PP.type_expression actual); + ("expression" , fun () -> Format.asprintf "%a" I.PP.expression expression) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp loc) + ] in + error ~data title message () + + let type_error ?(msg="") ~(expected: O.type_expression) ~(actual: O.type_expression) ~(expression : I.expression) (loc:Location.t) () = + let title = (thunk "type error") in + let message () = msg in + let data = [ + ("expected" , fun () -> Format.asprintf "%a" O.PP.type_expression expected); + ("actual" , fun () -> Format.asprintf "%a" O.PP.type_expression actual); + ("expression" , fun () -> Format.asprintf "%a" I.PP.expression expression) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp loc) + ] in + error ~data title message () + + let bad_tuple_index (index : int) (ae : I.expression) (t : O.type_expression) (loc:Location.t) () = + let title = (thunk "invalid tuple index") in + let message () = "" in + let data = [ + ("index" , fun () -> Format.asprintf "%d" index) ; + ("tuple_value" , fun () -> Format.asprintf "%a" I.PP.expression ae) ; + ("tuple_type" , fun () -> Format.asprintf "%a" O.PP.type_expression t) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp loc) + ] in + error ~data title message () + + let bad_record_access (field : string) (ae : I.expression) (t : O.type_expression) (loc:Location.t) () = + let title = (thunk "invalid record field") in + let message () = "" in + let data = [ + ("field" , fun () -> Format.asprintf "%s" field) ; + ("record_value" , fun () -> Format.asprintf "%a" I.PP.expression ae) ; + ("tuple_type" , fun () -> Format.asprintf "%a" O.PP.type_expression t) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp loc) + ] in + error ~data title message () + + let not_supported_yet (message : string) (ae : I.expression) () = + let title = (thunk "not supported yet") in + let message () = message in + let data = [ + 
("expression" , fun () -> Format.asprintf "%a" I.PP.expression ae) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp ae.location) + ] in + error ~data title message () + + let not_supported_yet_untranspile (message : string) (ae : O.expression) () = + let title = (thunk "not supported yet") in + let message () = message in + let data = [ + ("expression" , fun () -> Format.asprintf "%a" O.PP.expression ae) + ] in + error ~data title message () + + let constant_error loc lst tv_opt = + let title () = "typing constant" in + let message () = "" in + let data = [ + ("location" , fun () -> Format.asprintf "%a" Location.pp loc ) ; + ("argument_types" , fun () -> Format.asprintf "%a" PP_helpers.(list_sep Ast_typed.PP.type_expression (const " , ")) lst) ; + ("type_opt" , fun () -> Format.asprintf "%a" PP_helpers.(option Ast_typed.PP.type_expression) tv_opt) ; + ] in + error ~data title message +end +open Errors + +let rec type_program (p:I.program) : O.program result = + let aux (e, acc:(environment * O.declaration Location.wrap list)) (d:I.declaration Location.wrap) = + let%bind ed' = (bind_map_location (type_declaration e)) d in + let loc : 'a . 
'a Location.wrap -> _ -> _ = fun x v -> Location.wrap ~loc:x.location v in + let (e', d') = Location.unwrap ed' in + match d' with + | None -> ok (e', acc) + | Some d' -> ok (e', loc ed' d' :: acc) + in + let%bind (_, lst) = + trace (fun () -> program_error p ()) @@ + bind_fold_list aux (Environment.full_empty, []) p in + ok @@ List.rev lst + +and type_declaration env : I.declaration -> (environment * O.declaration option) result = function + | Declaration_type (type_name , type_expression) -> + let%bind tv = evaluate_type env type_expression in + let env' = Environment.add_type type_name tv env in + ok (env', None) + | Declaration_constant (name , tv_opt , expression) -> ( + let%bind tv'_opt = bind_map_option (evaluate_type env) tv_opt in + let%bind ae' = + trace (constant_declaration_error name expression tv'_opt) @@ + type_expression ?tv_opt:tv'_opt env expression in + let env' = Environment.add_ez_ae name ae' env in + ok (env', Some (O.Declaration_constant ((make_n_e name ae') , (env , env')))) + ) + +and type_match : type i o . 
(environment -> i -> o result) -> environment -> O.type_expression -> i I.matching -> I.expression -> Location.t -> o O.matching result = + fun f e t i ae loc -> match i with + | Match_bool {match_true ; match_false} -> + let%bind _ = + trace_strong (match_error ~expected:i ~actual:t loc) + @@ get_t_bool t in + let%bind match_true = f e match_true in + let%bind match_false = f e match_false in + ok (O.Match_bool {match_true ; match_false}) + | Match_option {match_none ; match_some} -> + let%bind t_opt = + trace_strong (match_error ~expected:i ~actual:t loc) + @@ get_t_option t in + let%bind match_none = f e match_none in + let (n, b) = match_some in + let n' = n, t_opt in + let e' = Environment.add_ez_binder n t_opt e in + let%bind b' = f e' b in + ok (O.Match_option {match_none ; match_some = (n', b')}) + | Match_list {match_nil ; match_cons} -> + let%bind t_list = + trace_strong (match_error ~expected:i ~actual:t loc) + @@ get_t_list t in + let%bind match_nil = f e match_nil in + let (hd, tl, b) = match_cons in + let e' = Environment.add_ez_binder hd t_list e in + let e' = Environment.add_ez_binder tl t e' in + let%bind b' = f e' b in + ok (O.Match_list {match_nil ; match_cons = (hd, tl, b')}) + | Match_tuple (lst, b) -> + let%bind t_tuple = + trace_strong (match_error ~expected:i ~actual:t loc) + @@ get_t_tuple t in + let%bind lst' = + generic_try (match_tuple_wrong_arity t_tuple lst loc) + @@ (fun () -> List.combine lst t_tuple) in + let aux prev (name, tv) = Environment.add_ez_binder name tv prev in + let e' = List.fold_left aux e lst' in + let%bind b' = f e' b in + ok (O.Match_tuple (lst, b')) + | Match_variant lst -> + let%bind variant_opt = + let aux acc ((constructor_name , _) , _) = + let%bind (_ , variant) = + trace_option (unbound_constructor e constructor_name loc) @@ + Environment.get_constructor constructor_name e in + let%bind acc = match acc with + | None -> ok (Some variant) + | Some variant' -> ( + trace (type_error + ~msg:"in match variant" + 
~expected:variant + ~actual:variant' + ~expression:ae + loc + ) @@ + Ast_typed.assert_type_expression_eq (variant , variant') >>? fun () -> + ok (Some variant) + ) in + ok acc in + trace (simple_info "in match variant") @@ + bind_fold_list aux None lst in + let%bind variant = + trace_option (match_empty_variant i loc) @@ + variant_opt in + let%bind () = + let%bind variant_cases' = + trace (match_error ~expected:i ~actual:t loc) + @@ Ast_typed.Combinators.get_t_sum variant in + let variant_cases = List.map fst @@ Map.String.to_kv_list variant_cases' in + let match_cases = List.map (Function.compose fst fst) lst in + let test_case = fun c -> + Assert.assert_true (List.mem c match_cases) + in + let%bind () = + trace_strong (match_missing_case i loc) @@ + bind_iter_list test_case variant_cases in + let%bind () = + trace_strong (match_redundant_case i loc) @@ + Assert.assert_true List.(length variant_cases = length match_cases) in + ok () + in + let%bind lst' = + let aux ((constructor_name , name) , b) = + let%bind (constructor , _) = + trace_option (unbound_constructor e constructor_name loc) @@ + Environment.get_constructor constructor_name e in + let e' = Environment.add_ez_binder name constructor e in + let%bind b' = f e' b in + ok ((constructor_name , name) , b') + in + bind_map_list aux lst in + ok (O.Match_variant (lst' , variant)) + +and evaluate_type (e:environment) (t:I.type_expression) : O.type_expression result = + let return tv' = ok (make_t tv' (Some t)) in + match t with + | T_function (a, b) -> + let%bind a' = evaluate_type e a in + let%bind b' = evaluate_type e b in + return (T_function (a', b')) + | T_tuple lst -> + let%bind lst' = bind_list @@ List.map (evaluate_type e) lst in + return (T_tuple lst') + | T_sum m -> + let aux k v prev = + let%bind prev' = prev in + let%bind v' = evaluate_type e v in + ok @@ SMap.add k v' prev' + in + let%bind m = SMap.fold aux m (ok SMap.empty) in + return (T_sum m) + | T_record m -> + let aux k v prev = + let%bind 
prev' = prev in + let%bind v' = evaluate_type e v in + ok @@ SMap.add k v' prev' + in + let%bind m = SMap.fold aux m (ok SMap.empty) in + return (T_record m) + | T_variable name -> + let%bind tv = + trace_option (unbound_type_variable e name) + @@ Environment.get_type_opt name e in + ok tv + | T_constant (cst, lst) -> + let%bind lst' = bind_list @@ List.map (evaluate_type e) lst in + return (T_constant(cst, lst')) + +and type_expression : environment -> ?tv_opt:O.type_expression -> I.expression -> O.annotated_expression result = fun e ?tv_opt ae -> + let module L = Logger.Stateful() in + let return expr tv = + let%bind () = + match tv_opt with + | None -> ok () + | Some tv' -> O.assert_type_expression_eq (tv' , tv) in + let location = Location.get_location ae in + ok @@ make_a_e ~location expr tv e in + let main_error = + let title () = "typing expression" in + let content () = "" in + let data = [ + ("expression" , fun () -> Format.asprintf "%a" I.PP.expression ae) ; + ("location" , fun () -> Format.asprintf "%a" Location.pp @@ Location.get_location ae) ; + ("misc" , fun () -> L.get ()) ; + ] in + error ~data title content in + trace main_error @@ + match Location.unwrap ae with + (* Basic *) + | E_failwith _ -> fail @@ needs_annotation ae "the failwith keyword" + | E_variable name -> + let%bind tv' = + trace_option (unbound_variable e name ae.location) + @@ Environment.get_opt name e in + return (E_variable name) tv'.type_expression + | E_literal (Literal_bool b) -> + return (E_literal (Literal_bool b)) (t_bool ()) + | E_literal Literal_unit | E_skip -> + return (E_literal (Literal_unit)) (t_unit ()) + | E_literal (Literal_string s) -> ( + L.log (Format.asprintf "literal_string option type: %a" PP_helpers.(option O.PP.type_expression) tv_opt) ; + match Option.map Ast_typed.get_type' tv_opt with + | Some (T_constant ("address" , [])) -> return (E_literal (Literal_address s)) (t_address ()) + | _ -> return (E_literal (Literal_string s)) (t_string ()) + ) + | 
E_literal (Literal_bytes s) -> + return (E_literal (Literal_bytes s)) (t_bytes ()) + | E_literal (Literal_int n) -> + return (E_literal (Literal_int n)) (t_int ()) + | E_literal (Literal_nat n) -> + return (E_literal (Literal_nat n)) (t_nat ()) + | E_literal (Literal_timestamp n) -> + return (E_literal (Literal_timestamp n)) (t_timestamp ()) + | E_literal (Literal_tez n) -> + return (E_literal (Literal_tez n)) (t_tez ()) + | E_literal (Literal_address s) -> + return (e_address s) (t_address ()) + | E_literal (Literal_operation op) -> + return (e_operation op) (t_operation ()) + (* Tuple *) + | E_tuple lst -> + let%bind lst' = bind_list @@ List.map (type_expression e) lst in + let tv_lst = List.map get_type_annotation lst' in + return (E_tuple lst') (t_tuple tv_lst ()) + | E_accessor (ae', path) -> + let%bind e' = type_expression e ae' in + let aux (prev:O.annotated_expression) (a:I.access) : O.annotated_expression result = + match a with + | Access_tuple index -> ( + let%bind tpl_tv = get_t_tuple prev.type_annotation in + let%bind tv = + generic_try (bad_tuple_index index ae' prev.type_annotation ae.location) + @@ (fun () -> List.nth tpl_tv index) in + return (E_tuple_accessor (prev , index)) tv + ) + | Access_record property -> ( + let%bind r_tv = get_t_record prev.type_annotation in + let%bind tv = + generic_try (bad_record_access property ae' prev.type_annotation ae.location) + @@ (fun () -> SMap.find property r_tv) in + return (E_record_accessor (prev , property)) tv + ) + | Access_map ae' -> ( + let%bind ae'' = type_expression e ae' in + let%bind (k , v) = get_t_map prev.type_annotation in + let%bind () = + Ast_typed.assert_type_expression_eq (k , get_type_annotation ae'') in + return (E_look_up (prev , ae'')) v + ) + in + trace (simple_info "accessing") @@ + bind_fold_list aux e' path + + (* Sum *) + | E_constructor (c, expr) -> + let%bind (c_tv, sum_tv) = + let error = + let title () = "no such constructor" in + let content () = + Format.asprintf "%s 
in:\n%a\n" + c O.Environment.PP.full_environment e + in + error title content in + trace_option error @@ + Environment.get_constructor c e in + let%bind expr' = type_expression e expr in + let%bind _assert = O.assert_type_expression_eq (expr'.type_annotation, c_tv) in + return (E_constructor (c , expr')) sum_tv + (* Record *) + | E_record m -> + let aux prev k expr = + let%bind expr' = type_expression e expr in + ok (SMap.add k expr' prev) + in + let%bind m' = bind_fold_smap aux (ok SMap.empty) m in + return (E_record m') (t_record (SMap.map get_type_annotation m') ()) + (* Data-structure *) + | E_list lst -> + let%bind lst' = bind_map_list (type_expression e) lst in + let%bind tv = + let aux opt c = + match opt with + | None -> ok (Some c) + | Some c' -> + let%bind _eq = Ast_typed.assert_type_expression_eq (c, c') in + ok (Some c') in + let%bind init = match tv_opt with + | None -> ok None + | Some ty -> + let%bind ty' = get_t_list ty in + ok (Some ty') in + let%bind ty = + let%bind opt = bind_fold_list aux init + @@ List.map get_type_annotation lst' in + trace_option (needs_annotation ae "empty list") opt in + ok (t_list ty ()) + in + return (E_list lst') tv + | E_set lst -> + let%bind lst' = bind_map_list (type_expression e) lst in + let%bind tv = + let aux opt c = + match opt with + | None -> ok (Some c) + | Some c' -> + let%bind _eq = Ast_typed.assert_type_expression_eq (c, c') in + ok (Some c') in + let%bind init = match tv_opt with + | None -> ok None + | Some ty -> + let%bind ty' = get_t_set ty in + ok (Some ty') in + let%bind ty = + let%bind opt = bind_fold_list aux init + @@ List.map get_type_annotation lst' in + trace_option (needs_annotation ae "empty set") opt in + ok (t_set ty ()) + in + return (E_set lst') tv + | E_map lst -> + let%bind lst' = bind_map_list (bind_map_pair (type_expression e)) lst in + let%bind tv = + let aux opt c = + match opt with + | None -> ok (Some c) + | Some c' -> + let%bind _eq = Ast_typed.assert_type_expression_eq (c, c') in 
+ ok (Some c') in + let%bind key_type = + let%bind sub = + bind_fold_list aux None + @@ List.map get_type_annotation + @@ List.map fst lst' in + let%bind annot = bind_map_option get_t_map_key tv_opt in + trace (simple_info "empty map expression without a type annotation") @@ + O.merge_annotation annot sub (needs_annotation ae "this map literal") + in + let%bind value_type = + let%bind sub = + bind_fold_list aux None + @@ List.map get_type_annotation + @@ List.map snd lst' in + let%bind annot = bind_map_option get_t_map_value tv_opt in + trace (simple_info "empty map expression without a type annotation") @@ + O.merge_annotation annot sub (needs_annotation ae "this map literal") + in + ok (t_map key_type value_type ()) + in + return (E_map lst') tv + | E_lambda { + binder ; + input_type ; + output_type ; + result ; + } -> ( + let%bind input_type = + let%bind input_type = + (* Hack to take care of let_in introduced by `simplify/ligodity.ml` in ECase's hack *) + let default_action e () = fail @@ (needs_annotation e "the returned value") in + match input_type with + | Some ty -> ok ty + | None -> ( + match Location.unwrap result with + | I.E_let_in li -> ( + match Location.unwrap li.rhs with + | I.E_variable name when name = (fst binder) -> ( + match snd li.binder with + | Some ty -> ok ty + | None -> default_action li.rhs () + ) + | _ -> default_action li.rhs () + ) + | _ -> default_action result () + ) + in + evaluate_type e input_type in + let%bind output_type = + bind_map_option (evaluate_type e) output_type + in + let e' = Environment.add_ez_binder (fst binder) input_type e in + let%bind result = type_expression ?tv_opt:output_type e' result in + let output_type = result.type_annotation in + return (E_lambda {binder = fst binder;input_type;output_type;result}) (t_function input_type output_type ()) + ) + | E_constant (name, lst) -> + let%bind lst' = bind_list @@ List.map (type_expression e) lst in + let tv_lst = List.map get_type_annotation lst' in + let%bind 
(name', tv) = + type_constant name tv_lst tv_opt ae.location in + return (E_constant (name' , lst')) tv + | E_application (f, arg) -> + let%bind f' = type_expression e f in + let%bind arg = type_expression e arg in + let%bind tv = match f'.type_annotation.type_expression' with + | T_function (param, result) -> + let%bind _ = O.assert_type_expression_eq (param, arg.type_annotation) in + ok result + | _ -> + fail @@ type_error_approximate + ~expected:"should be a function type" + ~expression:f + ~actual:f'.type_annotation + f'.location + in + return (E_application (f' , arg)) tv + | E_look_up dsi -> + let%bind (ds, ind) = bind_map_pair (type_expression e) dsi in + let%bind (src, dst) = get_t_map ds.type_annotation in + let%bind _ = O.assert_type_expression_eq (ind.type_annotation, src) in + return (E_look_up (ds , ind)) (t_option dst ()) + (* Advanced *) + | E_matching (ex, m) -> ( + let%bind ex' = type_expression e ex in + match m with + (* Special case for assert-like failwiths. TODO: CLEAN THIS. 
*) + | I.Match_bool { match_false ; match_true } when I.is_e_failwith match_true -> ( + let%bind fw = I.get_e_failwith match_true in + let%bind fw' = type_expression e fw in + let%bind mf' = type_expression e match_false in + let t = get_type_annotation ex' in + let%bind () = + trace_strong (match_error ~expected:m ~actual:t ae.location) + @@ assert_t_bool t in + let%bind () = + trace_strong (match_error + ~msg:"matching not-unit on an assert" + ~expected:m + ~actual:t + ae.location) + @@ assert_t_unit (get_type_annotation mf') in + let mt' = make_a_e + (E_constant ("ASSERT_INFERRED" , [ex' ; fw'])) + (t_unit ()) + e + in + let m' = O.Match_bool { match_true = mt' ; match_false = mf' } in + return (O.E_matching (ex' , m')) (t_unit ()) + ) + | _ -> ( + let%bind m' = type_match (type_expression ?tv_opt:None) e ex'.type_annotation m ae ae.location in + let tvs = + let aux (cur:O.value O.matching) = + match cur with + | Match_bool { match_true ; match_false } -> [ match_true ; match_false ] + | Match_list { match_nil ; match_cons = (_ , _ , match_cons) } -> [ match_nil ; match_cons ] + | Match_option { match_none ; match_some = (_ , match_some) } -> [ match_none ; match_some ] + | Match_tuple (_ , match_tuple) -> [ match_tuple ] + | Match_variant (lst , _) -> List.map snd lst in + List.map get_type_annotation @@ aux m' in + let aux prec cur = + let%bind () = + match prec with + | None -> ok () + | Some cur' -> Ast_typed.assert_type_expression_eq (cur , cur') in + ok (Some cur) in + let%bind tv_opt = bind_fold_list aux None tvs in + let%bind tv = + trace_option (match_empty_variant m ae.location) @@ + tv_opt in + return (O.E_matching (ex', m')) tv + ) + ) + | E_sequence (a , b) -> + let%bind a' = type_expression e a in + let%bind b' = type_expression e b in + let a'_type_annot = get_type_annotation a' in + let%bind () = + trace_strong (type_error + ~msg:"first part of the sequence should be of unit type" + ~expected:(O.t_unit ()) + ~actual:a'_type_annot + ~expression:a 
+ a'.location) @@ + Ast_typed.assert_type_expression_eq (t_unit () , a'_type_annot) in + return (O.E_sequence (a' , b')) (get_type_annotation b') + | E_loop (expr , body) -> + let%bind expr' = type_expression e expr in + let%bind body' = type_expression e body in + let t_expr' = get_type_annotation expr' in + let%bind () = + trace_strong (type_error + ~msg:"while condition isn't of type bool" + ~expected:(O.t_bool ()) + ~actual:t_expr' + ~expression:expr + expr'.location) @@ + Ast_typed.assert_type_expression_eq (t_bool () , t_expr') in + let t_body' = get_type_annotation body' in + let%bind () = + trace_strong (type_error + ~msg:"while body isn't of unit type" + ~expected:(O.t_unit ()) + ~actual:t_body' + ~expression:body + body'.location) @@ + Ast_typed.assert_type_expression_eq (t_unit () , t_body') in + return (O.E_loop (expr' , body')) (t_unit ()) + | E_assign (name , path , expr) -> + let%bind typed_name = + let%bind ele = Environment.get_trace name e in + ok @@ make_n_t name ele.type_expression in + let%bind (assign_tv , path') = + let aux : ((_ * O.access_path) as 'a) -> I.access -> 'a result = fun (prec_tv , prec_path) cur_path -> + match cur_path with + | Access_tuple index -> ( + let%bind tpl = get_t_tuple prec_tv in + let%bind tv' = + trace_option (bad_tuple_index index ae prec_tv ae.location) @@ + List.nth_opt tpl index in + ok (tv' , prec_path @ [O.Access_tuple index]) + ) + | Access_record property -> ( + let%bind m = get_t_record prec_tv in + let%bind tv' = + trace_option (bad_record_access property ae prec_tv ae.location) @@ + Map.String.find_opt property m in + ok (tv' , prec_path @ [O.Access_record property]) + ) + | Access_map _ -> + fail @@ not_supported_yet "assign expressions with maps are not supported yet" ae + in + bind_fold_list aux (typed_name.type_expression , []) path in + let%bind expr' = type_expression e expr in + let t_expr' = get_type_annotation expr' in + let%bind () = + trace_strong (type_error + ~msg:"type of the expression to 
assign doesn't match left-hand-side" + ~expected:assign_tv + ~actual:t_expr' + ~expression:expr + expr'.location) @@ + Ast_typed.assert_type_expression_eq (assign_tv , t_expr') in + return (O.E_assign (typed_name , path' , expr')) (t_unit ()) + | E_let_in {binder ; rhs ; result} -> + let%bind rhs_tv_opt = bind_map_option (evaluate_type e) (snd binder) in + let%bind rhs = type_expression ?tv_opt:rhs_tv_opt e rhs in + let e' = Environment.add_ez_declaration (fst binder) rhs e in + let%bind result = type_expression e' result in + return (E_let_in {binder = fst binder; rhs; result}) result.type_annotation + | E_annotation (expr , te) -> + let%bind tv = evaluate_type e te in + let%bind expr' = type_expression ~tv_opt:tv e expr in + let%bind type_annotation = + O.merge_annotation + (Some tv) + (Some expr'.type_annotation) + (internal_assertion_failure "merge_annotations (Some ...) (Some ...) failed") in + ok {expr' with type_annotation} + + +and type_constant (name:string) (lst:O.type_expression list) (tv_opt:O.type_expression option) (loc : Location.t) : (string * O.type_expression) result = + (* Constant poorman's polymorphism *) + let ct = Operators.Typer.constant_typers in + let%bind typer = + trace_option (unrecognized_constant name loc) @@ + Map.String.find_opt name ct in + trace (constant_error loc lst tv_opt) @@ + typer lst tv_opt + +let untype_type_expression (t:O.type_expression) : (I.type_expression) result = + match t.simplified with + | Some s -> ok s + | _ -> fail @@ internal_assertion_failure "trying to untype generated type" + +let untype_literal (l:O.literal) : I.literal result = + let open I in + match l with + | Literal_unit -> ok Literal_unit + | Literal_bool b -> ok (Literal_bool b) + | Literal_nat n -> ok (Literal_nat n) + | Literal_timestamp n -> ok (Literal_timestamp n) + | Literal_tez n -> ok (Literal_tez n) + | Literal_int n -> ok (Literal_int n) + | Literal_string s -> ok (Literal_string s) + | Literal_bytes b -> ok (Literal_bytes b) + | 
Literal_address s -> ok (Literal_address s) + | Literal_operation s -> ok (Literal_operation s) + +let rec untype_expression (e:O.annotated_expression) : (I.expression) result = + let open I in + let return e = ok e in + match e.expression with + | E_literal l -> + let%bind l = untype_literal l in + return (e_literal l) + | E_constant (n, lst) -> + let%bind lst' = bind_map_list untype_expression lst in + return (e_constant n lst') + | E_variable n -> + return (e_variable n) + | E_application (f, arg) -> + let%bind f' = untype_expression f in + let%bind arg' = untype_expression arg in + return (e_application f' arg') + | E_lambda {binder;input_type;output_type;result} -> + let%bind input_type = untype_type_expression input_type in + let%bind output_type = untype_type_expression output_type in + let%bind result = untype_expression result in + return (e_lambda binder (Some input_type) (Some output_type) result) + | E_tuple lst -> + let%bind lst' = bind_list + @@ List.map untype_expression lst in + return (e_tuple lst') + | E_tuple_accessor (tpl, ind) -> + let%bind tpl' = untype_expression tpl in + return (e_accessor tpl' [Access_tuple ind]) + | E_constructor (n, p) -> + let%bind p' = untype_expression p in + return (e_constructor n p') + | E_record r -> + let%bind r' = bind_smap + @@ SMap.map untype_expression r in + return (e_record r') + | E_record_accessor (r, s) -> + let%bind r' = untype_expression r in + return (e_accessor r' [Access_record s]) + | E_map m -> + let%bind m' = bind_map_list (bind_map_pair untype_expression) m in + return (e_map m') + | E_list lst -> + let%bind lst' = bind_map_list untype_expression lst in + return (e_list lst') + | E_set lst -> + let%bind lst' = bind_map_list untype_expression lst in + return (e_set lst') + | E_look_up dsi -> + let%bind (a , b) = bind_map_pair untype_expression dsi in + return (e_look_up a b) + | E_matching (ae, m) -> + let%bind ae' = untype_expression ae in + let%bind m' = untype_matching untype_expression m in + 
return (e_matching ae' m') + | E_failwith ae -> + let%bind ae' = untype_expression ae in + return (e_failwith ae') + | E_sequence _ + | E_loop _ + | E_assign _ -> fail @@ not_supported_yet_untranspile "not possible to untranspile statements yet" e.expression + | E_let_in {binder;rhs;result} -> + let%bind tv = untype_type_expression rhs.type_annotation in + let%bind rhs = untype_expression rhs in + let%bind result = untype_expression result in + return (e_let_in (binder , (Some tv)) rhs result) + +and untype_matching : type o i . (o -> i result) -> o O.matching -> (i I.matching) result = fun f m -> + let open I in + match m with + | Match_bool {match_true ; match_false} -> + let%bind match_true = f match_true in + let%bind match_false = f match_false in + ok @@ Match_bool {match_true ; match_false} + | Match_tuple (lst, b) -> + let%bind b = f b in + ok @@ Match_tuple (lst, b) + | Match_option {match_none ; match_some = (v, some)} -> + let%bind match_none = f match_none in + let%bind some = f some in + let match_some = fst v, some in + ok @@ Match_option {match_none ; match_some} + | Match_list {match_nil ; match_cons = (hd, tl, cons)} -> + let%bind match_nil = f match_nil in + let%bind cons = f cons in + let match_cons = hd, tl, cons in + ok @@ Match_list {match_nil ; match_cons} + | Match_variant (lst , _) -> + let aux ((a,b),c) = + let%bind c' = f c in + ok ((a,b),c') in + let%bind lst' = bind_map_list aux lst in + ok @@ Match_variant lst' diff --git a/src/passes/operators/dune b/src/passes/operators/dune index 0bd5db43d..f2125905a 100644 --- a/src/passes/operators/dune +++ b/src/passes/operators/dune @@ -5,6 +5,7 @@ simple-utils tezos-utils ast_typed + typesystem mini_c ) (preprocess diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 47627a440..5eb1dfb91 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -219,6 +219,65 @@ module Typer = struct open Helpers.Typer open Ast_typed + module 
Operators_types = struct
    (* Type schemes of the built-in operators, written with the
       combinators of [Typesystem.Shorthands].

       IMPORTANT: [-->] is an infix symbol starting with [-], so OCaml
       gives it the precedence and the LEFT associativity of [+] and [-].
       An unparenthesised [a --> b --> c] therefore builds
       [(a --> b) --> c], which is not the curried type [a -> (b -> c)].
       Every multi-argument arrow below is parenthesised to the right
       explicitly. [*] binds tighter than [-->], and [=>] binds looser. *)
    open Typesystem.Shorthands

    (* Typeclass constraints: each lists, in extension, the allowed
       instantiations of the given type variables. *)
    let tc_subarg   a b c = tc [a;b;c] [ (*TODO…*) ]
    let tc_sizearg  a     = tc [a]     [ [int] ]
    let tc_packable a     = tc [a]     [ [int] ; [string] ; [bool] (*TODO…*) ]
    let tc_timargs  a b c = tc [a;b;c] [ [nat;nat;nat] ; [int;int;int] (*TODO…*) ]
    let tc_divargs  a b c = tc [a;b;c] [ (*TODO…*) ]
    let tc_modargs  a b c = tc [a;b;c] [ (*TODO…*) ]
    let tc_addargs  a b c = tc [a;b;c] [ (*TODO…*) ]

    let t_none         = forall "a" @@ fun a -> option a
    let t_sub          = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_subarg a b c] => a --> (b --> c) (* TYPECLASS *)
    let t_some         = forall "a" @@ fun a -> a --> option a
    let t_map_remove   = forall2 "src" "dst" @@ fun src dst -> src --> (map src dst --> map src dst)
    let t_map_add      = forall2 "src" "dst" @@ fun src dst -> src --> (dst --> (map src dst --> map src dst))
    let t_map_update   = forall2 "src" "dst" @@ fun src dst -> src --> (option dst --> (map src dst --> map src dst))
    let t_map_mem      = forall2 "src" "dst" @@ fun src dst -> src --> (map src dst --> bool)
    let t_map_find     = forall2 "src" "dst" @@ fun src dst -> src --> (map src dst --> dst)
    let t_map_find_opt = forall2 "src" "dst" @@ fun src dst -> src --> (map src dst --> option dst)
    let t_map_fold     = forall3 "src" "dst" "acc" @@ fun src dst acc -> ( ( (src * dst) * acc ) --> acc ) --> (map src dst --> (acc --> acc))
    let t_map_map      = forall3 "k" "v" "result" @@ fun k v result -> ((k * v) --> result) --> (map k v --> map k result)

    (* TODO: the type of map_map_fold might be wrong, check it. *)
    let t_map_map_fold = forall4 "k" "v" "acc" "dst" @@ fun k v acc dst -> ( ((k * v) * acc) --> acc * dst ) --> (map k v --> ((k * v) --> (map k dst * acc)))
    let t_map_iter     = forall2 "k" "v" @@ fun k v -> ( (k * v) --> unit ) --> (map k v --> unit)
    let t_size         = forall_tc "c" @@ fun c -> [tc_sizearg c] => c --> nat (* TYPECLASS *)
    let t_slice        = nat --> (nat --> (string --> string))
    let t_failwith     = string --> unit
    let t_get_force    = forall2 "src" "dst" @@ fun src dst -> src --> (map src dst --> dst)
    let t_int          = nat --> int
    let t_bytes_pack   = forall_tc "a" @@ fun a -> [tc_packable a] => a --> bytes (* TYPECLASS *)
    let t_bytes_unpack = forall_tc "a" @@ fun a -> [tc_packable a] => bytes --> a (* TYPECLASS *)
    let t_hash256      = bytes --> bytes
    let t_hash512      = bytes --> bytes
    let t_blake2b      = bytes --> bytes
    let t_hash_key     = key --> key_hash
    let t_check_signature = key --> (signature --> (bytes --> bool))
    let t_sender       = address
    let t_source       = address
    let t_unit         = unit
    let t_amount       = tez
    let t_address      = address
    let t_now          = timestamp
    let t_transaction  = forall "a" @@ fun a -> a --> (tez --> (contract a --> operation))
    let t_get_contract = forall "a" @@ fun a -> contract a
    let t_abs          = int --> nat
    let t_cons         = forall "a" @@ fun a -> a --> (list a --> list a)
    let t_assertion    = bool --> unit
    let t_times        = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_timargs a b c] => a --> (b --> c) (* TYPECLASS *)
    let t_div          = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_divargs a b c] => a --> (b --> c) (* TYPECLASS *)
    let t_mod          = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_modargs a b c] => a --> (b --> c) (* TYPECLASS *)
    let t_add          = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_addargs a b c] => a --> (b --> c) (* TYPECLASS *)
    let t_set_mem      = forall "a" @@ fun a -> a --> (set a --> bool)
    let t_set_add      = forall "a" @@ fun a -> a --> (set a --> set a)
    let t_set_remove   = forall "a" @@ fun a -> a --> (set a --> set a)
    let t_not          = bool --> bool
  end

  let none = typer_0 "NONE"
@@ fun tv_opt ->
    match tv_opt with
    | None -> simple_fail "untyped NONE"
@@ -647,6 +706,7 @@ module Typer = struct
    get_contract ;
    neg ;
    abs ;
+    cons ;
    now ;
    slice ;
    address ;
diff --git a/src/typesystem/core.ml b/src/typesystem/core.ml
new file mode 100644
index 000000000..69a5d413c
--- /dev/null
+++ b/src/typesystem/core.ml
@@ -0,0 +1,60 @@
(** Names of type variables. *)
type type_variable = string

(** [fresh_type_variable ?name ()] returns a type-variable name that has
    not been returned before. A private counter is incremented on every
    call; the result is ["type_variable_<n>"], or ["tv_<name>_<n>"] when
    the hint [name] is given. *)
let fresh_type_variable : ?name:string -> unit -> type_variable =
  let id = ref 0 in
  fun ?name () ->
    incr id ;
    let suffix = string_of_int !id in
    match name with
    | None -> "type_variable_" ^ suffix
    | Some name -> "tv_" ^ name ^ "_" ^ suffix


(** Type constructors; the trailing comment of each case gives its kind. *)
type constant_tag =
  | C_arrow     (* * -> * -> * *)
  | C_option    (* * -> * *)
  | C_tuple     (* * … -> * *)
  | C_record    (* ( label , * ) … -> * *)
  | C_variant   (* ( label , * ) … -> * *)
  | C_map       (* * -> * -> * *)
  | C_list      (* * -> * *)
  | C_set       (* * -> * *)
  | C_unit      (* * *)
  | C_bool      (* * *)
  | C_string    (* * *)
  | C_nat       (* * *)
  | C_tez       (* * *)
  | C_timestamp (* * *)
  | C_int       (* * *)
  | C_address   (* * *)
  | C_bytes     (* * *)
  | C_key_hash  (* * *)
  | C_key       (* * *)
  | C_signature (* * *)
  | C_operation (* * *)
  | C_contract  (* * -> * *)

(** Labels used by [C_access_label]: a position or a field name. *)
type label =
  | L_int of int
  | L_string of string

(** Type values: a quantified type carrying its constraints, a type
    variable, or a type constructor applied to its arguments. *)
type type_value =
  | P_forall       of (type_variable * type_constraint list * type_value)
  | P_variable     of type_variable
  | P_constant     of (constant_tag * type_value list)

and simple_c_constructor = (constant_tag * type_variable list) (* non-empty list *)
and simple_c_constant = (constant_tag) (* for type constructors that do not take arguments *)
and c_const = (type_variable * type_value)
and c_equation = (type_value * type_value)
and c_typeclass = (type_value list * typeclass)
and c_access_label = (type_value * label * type_variable)

and type_constraint =
  (* | C_assignment of (type_variable * type_pattern) *)
  | C_equation of c_equation                  (* TVA = TVB *)
  | C_typeclass of c_typeclass                (* TVL ∈ TVLs, for now in extension, later add intensional (rule-based system for inclusion in the typeclass) *)
  | C_access_label of c_access_label          (* poor man's type-level computation to ensure that TV.label is type_variable *)
(* | … *)

and typeclass = type_value list list
diff --git a/src/typesystem/dune b/src/typesystem/dune
new file mode 100644
index 000000000..d5e1deaf6
--- /dev/null
+++ b/src/typesystem/dune
@@ -0,0 +1,14 @@
(library
  (name typesystem)
  (public_name ligo.typesystem)
  (libraries
    simple-utils
    tezos-utils
    ast_typed
    mini_c
  )
  (preprocess
    (pps ppx_let)
  )
  (flags (:standard -w +1..62-4-9-44-40-42-48-30@39@33 -open Simple_utils ))
)
diff --git a/src/typesystem/shorthands.ml b/src/typesystem/shorthands.ml
new file mode 100644
index 000000000..2bf16dd9c
--- /dev/null
+++ b/src/typesystem/shorthands.ml
@@ -0,0 +1,62 @@
open Core

(** [tc type_vars allowed_list] is the typeclass constraint stating that
    [type_vars] must be one of the rows of [allowed_list]. *)
let tc type_vars allowed_list =
  Core.C_typeclass (type_vars , allowed_list)

(** [forall binder f] quantifies the type built by [f] over a fresh type
    variable, with no constraint attached. [binder] is used as the name
    hint of the fresh variable, so generated names stay readable. The
    variable is allocated before [f] is called. *)
let forall binder f =
  let freshvar = fresh_type_variable ~name:binder () in
  P_forall (freshvar , [] , f (P_variable freshvar))

(** Same as [forall], but [f] returns a pair (constraints , type) — see
    [(=>)] — and the constraints are attached to the quantifier. *)
let forall_tc binder f =
  let freshvar = fresh_type_variable ~name:binder () in
  let (tc, ty) = f (P_variable freshvar) in
  P_forall (freshvar , tc , ty)

(** [forallN] nest [N] calls to [forall], outermost binder first. *)
let forall2 a b f =
  forall a @@ fun a' ->
  forall b @@ fun b' ->
  f a' b'

let forall3 a b c f =
  forall a @@ fun a' ->
  forall b @@ fun b' ->
  forall c @@ fun c' ->
  f a' b' c'

let forall4 a b c d f =
  forall a @@ fun a' ->
  forall b @@ fun b' ->
  forall c @@ fun c' ->
  forall d @@ fun d' ->
  f a' b' c' d'

(** Three nested quantifiers; the constraints returned by [f] are
    attached to the innermost one (the one binding [c]). *)
let forall3_tc a b c f =
  forall a @@ fun a' ->
  forall b @@ fun b' ->
  forall_tc c @@ fun c' ->
  f a' b' c'

(** [arg --> ret] is the arrow type from [arg] to [ret].

    WARNING: as an infix symbol starting with [-], this operator is
    LEFT-associative (same level as [+] and [-]): [a --> b --> c] is
    parsed as [(a --> b) --> c]. Write [a --> (b --> c)] for a curried
    function type. *)
let (-->) arg ret = P_constant (C_arrow     , [arg; ret])

(** [constraints => ty] pairs a list of constraints with a type, in the
    shape expected by [forall_tc]. It binds looser than [-->]. *)
let (=>) tc ty = (tc , ty)

(* One shorthand per type constructor of [Core.constant_tag]. *)
let option t  = P_constant (C_option    , [t])
let pair a b  = P_constant (C_tuple     , [a; b])
let map  k v  = P_constant (C_map       , [k; v])
let unit      = P_constant (C_unit      , [])
let list   t  = P_constant (C_list      , [t])
let set    t  = P_constant (C_set       , [t])
let bool      = P_constant (C_bool      , [])
let string    = P_constant (C_string    , [])
let nat       = P_constant (C_nat       , [])
let tez       = P_constant (C_tez       , [])
let timestamp = P_constant (C_timestamp , [])
let int       = P_constant (C_int       , [])
let address   = P_constant (C_address   , [])
let bytes     = P_constant (C_bytes     , [])
let key       = P_constant (C_key       , [])
let key_hash  = P_constant (C_key_hash  , [])
let signature = P_constant (C_signature , [])
let operation = P_constant (C_operation , [])
let contract t = P_constant (C_contract , [t])

(** [a * b] is the pair type; it shadows integer multiplication wherever
    this module is opened. *)
let ( * ) a b = pair a b
diff --git a/src/typesystem/typesystem.ml b/src/typesystem/typesystem.ml
new file mode 100644
index 000000000..b97e373e9
--- /dev/null
+++ b/src/typesystem/typesystem.ml
@@ -0,0 +1,2 @@
module Core = Core
module Shorthands = Shorthands
diff --git a/src/union_find/.PartitionMain.tag b/src/union_find/.PartitionMain.tag
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/union_find/.links b/src/union_find/.links
new file mode 100644
index 000000000..b79d096bc
--- /dev/null
+++ b/src/union_find/.links
@@ -0,0 +1 @@
../OCaml-build/Makefile
diff --git a/src/union_find/LICENSE b/src/union_find/LICENSE
new file mode 100644
index 000000000..33a225af0
--- /dev/null
+++ b/src/union_find/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 Christian Rinderknecht

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
diff --git a/src/union_find/Makefile.cfg b/src/union_find/Makefile.cfg
new file mode 100644
index 000000000..13c016eb6
--- /dev/null
+++ b/src/union_find/Makefile.cfg
@@ -0,0 +1,4 @@
SHELL := dash
BFLAGS := -strict-sequence -w +A-48-4
#OCAMLC := ocamlcp
#OCAMLOPT := ocamloptp
diff --git a/src/union_find/Partition.mli b/src/union_find/Partition.mli
new file mode 100644
index 000000000..657b3c007
--- /dev/null
+++ b/src/union_find/Partition.mli
@@ -0,0 +1,64 @@
(** Abstract data type of a partition of items into equivalence
    classes (Union & Find). *)

(** Input signature of the functor {!Partition.Make}. Items must be
    totally ordered and, to ease debugging, printable. *)
module type Item =
  sig
    (** Type of items *)
    type t

    (** Total order on items; same convention as
        {!Pervasives.compare}. *)
    val compare : t -> t -> int

    (** Printable form of an item, used by [print]. *)
    val to_string : t -> string
  end

(** Output signature of the functor {!Partition.Make}. *)
module type S =
  sig
    type item
    type partition
    type t = partition

    (** {1 Creation} *)

    (** The partition containing no item. *)
    val empty : partition

    (** [equiv i j p] is [p] extended with the equivalence of the items
        [i] and [j]. If both are already known to be equivalent, then
        [equiv i j p == p]. *)
    val equiv : item -> item -> partition -> partition

    (** [alias i j p] is [p] extended with the fact that [i] is an alias
        of [j]. This is the same as [equiv i j p], except that [i] is
        guaranteed not to be the representative of its equivalence
        class in the result. *)
    val alias : item -> item -> partition -> partition

    (** {1 Projection} *)

    (** [repr i p] is the representative of the item [i] in the
        partition [p].
        @raise Not_found if [i] is not in [p]. *)
    val repr : item -> partition -> item

    (** [print p] prints the partition [p] on standard output, using
        [Ord.to_string] for the items. *)
    val print : partition -> unit

    (** {1 Predicates} *)

    (** [is_equiv i j p] is [true] if, and only if, the items [i] and
        [j] have the same representative in [p], that is, belong to the
        same equivalence class. *)
    val is_equiv : item -> item -> partition -> bool
  end

module Make (Ord : Item) : S with type item = Ord.t
diff --git a/src/union_find/Partition0.ml b/src/union_find/Partition0.ml
new file mode 100644
index 000000000..968bb8dd4
--- /dev/null
+++ b/src/union_find/Partition0.ml
@@ -0,0 +1,47 @@
(* Naive persistent implementation of Union/Find: O(n^2) worst case *)

module Make (Item: Partition.Item) =
  struct

    type item = Item.t
    type repr = item   (** Class representatives *)

    let equal i j = Item.compare i j = 0

    module ItemMap = Map.Make (Item)

    type height = int

    (** Every known item is mapped to its parent; an item mapped to
        itself is the representative of its class. *)
    type partition = item ItemMap.t
    type t = partition

    let empty = ItemMap.empty

    (** Follows the parent links from [item] up to the representative.
        Raises [Not_found] if an item on the way is not in the map. *)
    let rec repr item partition =
      match ItemMap.find item partition with
      | parent when equal parent item -> item
      | parent -> repr parent partition

    let is_equiv (i: item) (j: item) (p: partition) =
      let root_i = repr i p in
      equal root_i (repr j p)

    (** Representative of [i], registering [i] as its own class first
        when it is not yet in the partition. *)
    let get_or_set (i: item) (p: partition) : item * partition =
      match repr i p with
      | root -> root, p
      | exception Not_found -> i, ItemMap.add i i p

    (** Merges the classes of [i] and [j] by making the representative
        of [i] point to the representative of [j]. *)
    let equiv (i: item) (j :item) (p: partition) : partition =
      let root_i, with_i = get_or_set i p in
      let root_j, with_both = get_or_set j with_i in
      if equal root_i root_j
      then with_both
      else ItemMap.add root_i root_j with_both

    let alias = equiv

    (* Printing *)

    (** Prints one "item -> parent" line per binding of the map. *)
    let print p =
      ItemMap.iter
        (fun src dst ->
           Printf.printf "%s -> %s\n"
             (Item.to_string src) (Item.to_string dst))
        p

  end
diff --git a/src/union_find/Partition1.ml b/src/union_find/Partition1.ml
new file mode 100644
index 000000000..764d98d49
--- /dev/null
+++ b/src/union_find/Partition1.ml
@@ -0,0 +1,69 @@
(* Persistent implementation of Union/Find with height-balanced
   forests and without path compression: O(n*log(n)).

   In the definition of type [t], the height component is that of the
   source, that is, if [ItemMap.find i m = (j,h)], then [h] is the
   height of [i] (_not_ [j]).
*)

module Make (Item: Partition.Item) =
  struct

    type item = Item.t
    type repr = item   (** Class representatives *)

    let equal i j = Item.compare i j = 0

    module ItemMap = Map.Make (Item)

    type height = int

    (** Every known item is mapped to its parent paired with its own
        height; a representative is mapped to itself. *)
    type partition = (item * height) ItemMap.t
    type t = partition

    let empty = ItemMap.empty

    (** Follows the parent links from [i] and returns the binding of the
        representative: the representative and its height. Raises
        [Not_found] if an item on the way is not in the map. *)
    let rec seek (i: item) (p: partition) : repr * height =
      match ItemMap.find i p with
      | (parent, _) as binding ->
          if equal i parent then binding else seek parent p

    let repr item partition =
      let root, _ = seek item partition in
      root

    let is_equiv (i: item) (j: item) (p: partition) =
      let root_i = repr i p in
      equal root_i (repr j p)

    (** Result of [seek], registering [i] as a class of height 0 first
        when it is not yet in the partition. *)
    let get_or_set (i: item) (p: partition) =
      match seek i p with
      | found -> found, p
      | exception Not_found ->
          let fresh = i, 0 in
          fresh, ItemMap.add i fresh p

    (** Merges the classes of [i] and [j]: the lower tree is linked under
        the root of the higher one; on equal heights the root of [i] is
        linked under the root of [j], whose height grows by one. *)
    let equiv (i: item) (j: item) (p: partition) : partition =
      let (root_i, height_i), p = get_or_set i p in
      let (root_j, height_j), p = get_or_set j p in
      if equal root_i root_j then p
      else if height_i > height_j then
        ItemMap.add root_j (root_i, height_j) p
      else
        let p =
          if height_i < height_j
          then p
          else ItemMap.add root_j (root_j, height_j + 1) p
        in
        ItemMap.add root_i (root_j, height_i) p

    (** Like [equiv], except that when [i] is a representative, or when
        both trees have the same height, the root of [i] is always the
        one linked under the root of [j]. *)
    let alias (i: item) (j: item) (p: partition) : partition =
      let (root_i, height_i), p = get_or_set i p in
      let (root_j, height_j), p = get_or_set j p in
      if equal root_i root_j then p
      else if height_i = height_j || equal root_i i then
        let p =
          ItemMap.add root_j (root_j, max height_j (height_i + 1)) p in
        ItemMap.add root_i (root_j, height_i) p
      else if height_i < height_j then
        ItemMap.add root_i (root_j, height_i) p
      else
        ItemMap.add root_j (root_i, height_j) p

    (* Printing *)

    (** Prints one "item,height -> parent,height" line per binding. *)
    let print (p: partition) =
      let show i (parent, height_i) =
        let _, height_parent = ItemMap.find parent p in
        Printf.printf "%s,%d -> %s,%d\n"
          (Item.to_string i) height_i
          (Item.to_string parent) height_parent
      in
      ItemMap.iter show p

  end
diff --git a/src/union_find/Partition2.ml b/src/union_find/Partition2.ml
new file mode 100644
index 000000000..e1372b2fd
--- /dev/null
+++ b/src/union_find/Partition2.ml
@@ -0,0 +1,115 @@
(** Persistent implementation of the Union/Find algorithm with
    height-balanced forests and without path compression.
*)

module Make (Item: Partition.Item) =
  struct

    type item = Item.t
    type repr = item   (** Class representatives *)

    let equal i j = Item.compare i j = 0

    type height = int

    (** Each equivalence class is implemented by a Catalan tree linked
        upwardly and otherwise is a link to another node. Those trees
        are height-balanced. The type [node] implements nodes in those
        trees. *)
    type node =
      | Root of height
      (** [Root h] is the root of a tree, that is, the representative
          of the associated class. [h] is the height of the tree, so a
          tree reduced to its root alone has height 0. *)

      | Link of item * height
      (** A node that is not a root links to another node. Links go
          upward and the implementation is purely functional, so the
          first component is an item rather than a node: that is why
          [node] is not recursive. Climbing to the root means taking
          the item of a link, finding its node in the [partition]
          below, and repeating until a [Root] is reached.

          The height component is that of the source of the link: in
          the binding of [i] to [Link (j,h)], [h] is the height of
          [i], not of [j]. *)

    module ItemMap = Map.Make (Item)

    (** The type [partition] implements a partition of classes of
        equivalent items by means of a map from items to nodes of type
        [node] in trees. *)
    type partition = node ItemMap.t

    type t = partition

    let empty = ItemMap.empty

    (** [root (item, height) p] binds [item] to a root of that height. *)
    let root (item, height) = ItemMap.add item (Root height)

    (** [link (src, height) dst p] binds [src] to a link towards [dst]. *)
    let link (src, height) dst = ItemMap.add src (Link (dst, height))

    (** Representative of [i] together with the height of its tree.
        Raises [Not_found] if an item on the way is not in the map. *)
    let rec seek (i: item) (p: partition) : repr * height =
      match ItemMap.find i p with
      | Root height -> i, height
      | Link (parent, _) -> seek parent p

    let repr item partition =
      let representative, _ = seek item partition in
      representative

    let is_equiv (i: item) (j: item) (p: partition) =
      let root_i = repr i p in
      equal root_i (repr j p)

    (** Result of [seek], registering [i] as a root of height 0 first
        when it is not yet in the partition. *)
    let get_or_set (i: item) (p: partition) =
      match seek i p with
      | found -> found, p
      | exception Not_found ->
          let fresh = i, 0 in
          fresh, root fresh p

    (** Merges the classes of [i] and [j]: the lower tree is linked under
        the root of the higher one; on equal heights the root of [i] is
        linked under the root of [j], whose height grows by one. *)
    let equiv (i: item) (j: item) (p: partition) : partition =
      let (root_i, height_i as node_i), p = get_or_set i p in
      let (root_j, height_j as node_j), p = get_or_set j p in
      if equal root_i root_j then p
      else if height_i > height_j then link node_j root_i p
      else
        let p =
          if height_i < height_j then p else root (root_j, height_j + 1) p
        in
        link node_i root_j p

    (** The call [alias i j p] results in the same partition as [equiv
        i j p], except that [i] is not the representative of its class
        in [alias i j p] (whilst it may be in [equiv i j p]).

        This holds whatever the heights of the trees of [i] and [j].
        When [i] is not a representative before the call, the height
        criterion is applied. *)
    let alias (i: item) (j: item) (p: partition) : partition =
      let (root_i, height_i as node_i), p = get_or_set i p in
      let (root_j, height_j as node_j), p = get_or_set j p in
      if equal root_i root_j then p
      else if height_i = height_j || equal root_i i then
        let p = root (root_j, max height_j (height_i + 1)) p in
        link node_i root_j p
      else if height_i < height_j then link node_i root_j p
      else link node_j root_i p

    (** {1 Printing} *)

    (** Prints one "item,height -> parent,height" line per binding; a
        root is printed as pointing to itself. *)
    let print (p: partition) =
      let show i node =
        let hi, hj, j =
          match node with
          | Root height -> height, height, i
          | Link (parent, height) ->
              (match ItemMap.find parent p with
               | Root parent_height | Link (_, parent_height) ->
                   height, parent_height, parent)
        in
        Printf.printf "%s,%d -> %s,%d\n"
          (Item.to_string i) hi (Item.to_string j) hj
      in
      ItemMap.iter show p

  end
diff --git a/src/union_find/Partition3.ml b/src/union_find/Partition3.ml
new file mode 100644
index 000000000..593292025
--- /dev/null
+++ b/src/union_find/Partition3.ml
@@ -0,0 +1,86 @@
(* Destructive implementation of union/find with height-balanced
   forests but without path compression: O(n*log(n)). *)

module Make (Item: Partition.Item) =
  struct

    type item = Item.t
    type repr = item   (** Class representatives *)

    let equal i j = Item.compare i j = 0

    type height = int

    (** Each equivalence class is implemented by a Catalan tree linked
        upwardly and otherwise is a link to another node. Those trees
        are height-balanced. The type [node] implements nodes in those
        trees. A root is a node whose [parent] is itself. *)
    type node = {item: item; mutable height: int; mutable parent: node}

    module ItemMap = Map.Make (Item)

    (** The type [partition] implements a partition of classes of
        equivalent items by means of a map from items to nodes of type
        [node] in trees. *)
    type partition = node ItemMap.t

    type t = partition

    let empty = ItemMap.empty

    (** [seek i p] is the root node of the tree containing [i]. Only the
        first node is looked up in the map; the rest of the path is
        walked through the [parent] field, in constant time per node,
        whereas the persistent implementations pay a map lookup for
        every node on the path. Raises [Not_found] if [i] is not in
        [p]. *)
    let seek (i: item) (p: partition) : node =
      let rec climb node =
        if node.parent == node then node else climb node.parent
      in
      climb (ItemMap.find i p)

    let repr item partition =
      let root = seek item partition in
      root.item

    let is_equiv (i: item) (j: item) (p: partition) =
      let root_i = repr i p in
      equal root_i (repr j p)

    (** Result of [seek], registering [item] as a self-rooted node of
        height 0 first when it is not yet in the partition. *)
    let get_or_set item (p: partition) =
      match seek item p with
      | found -> found, p
      | exception Not_found ->
          let rec self = {item; height=0; parent=self} in
          self, ItemMap.add item self p

    (** [link src dst] makes [dst] the parent of [src], in place. *)
    let link src dst = src.parent <- dst

    (** Merges the classes of [i] and [j] by mutating their root nodes;
        the returned map only differs from [p] by the items that had to
        be registered. *)
    let equiv (i: item) (j: item) (p: partition) : partition =
      let ni, p = get_or_set i p in
      let nj, p = get_or_set j p in
      let hi, hj = ni.height, nj.height in
      if not (equal ni.item nj.item) then begin
        if hi > hj then link nj ni
        else begin
          link ni nj;
          nj.height <- max hj (hi + 1)
        end
      end;
      p

    (** Like [equiv], except that when [i] is a representative, or when
        both trees have the same height, the root of [i] is always the
        one linked under the root of [j]. *)
    let alias (i: item) (j: item) (p: partition) : partition =
      let ni, p = get_or_set i p in
      let nj, p = get_or_set j p in
      let hi, hj = ni.height, nj.height in
      if not (equal ni.item nj.item) then begin
        if hi = hj || equal ni.item i then begin
          link ni nj;
          nj.height <- max hj (hi + 1)
        end
        else if hi < hj then link ni nj
        else link nj ni
      end;
      p

    (* Printing *)

    (** Prints one "item,height -> parent,height" line per binding. *)
    let print p =
      let show _ node =
        Printf.printf "%s,%d -> %s,%d\n"
          (Item.to_string node.item) node.height
          (Item.to_string node.parent.item) node.parent.height
      in
      ItemMap.iter show p

  end
diff --git a/src/union_find/PartitionMain.ml b/src/union_find/PartitionMain.ml
new file mode 100644
index 000000000..4e69dbd87
---
/dev/null
+++ b/src/union_find/PartitionMain.ml
@@ -0,0 +1,40 @@
(** Integer items, as required by [Partition.Item]. *)
module Int =
  struct
    type t = int
    (* The binding is not recursive, so [compare] on the right-hand side
       is the standard polymorphic comparison, specialised to [int] by
       the annotations. Referring to it unqualified avoids the
       [Pervasives] module, which is deprecated since OCaml 4.08 and
       removed in OCaml 5. *)
    let compare (i: int) (j: int) = compare i j
    let to_string = string_of_int
  end

(** Builds a partition with a fixed sequence of equivalences and prints
    it on standard output. Nothing is asserted: the output is meant to
    be compared across implementations. *)
module Test (Part: Partition.S with type item = Int.t) =
  struct
    open Part

    let () = empty
             |> equiv 4 3
             |> equiv 3 8
             |> equiv 6 5
             |> equiv 9 4
             |> equiv 2 1
             |> equiv 8 9
             |> equiv 5 0
             |> equiv 7 2
             |> equiv 6 1
             |> equiv 1 0
             |> equiv 6 7
             |> equiv 8 0
             |> equiv 7 7
             |> equiv 10 10
             |> print
  end


(* The same scenario is run against each implementation, with a blank
   line between two outputs. *)
module Test0 = Test (Partition0.Make(Int))
let () = print_newline ()

module Test1 = Test (Partition1.Make(Int))
let () = print_newline ()

module Test2 = Test (Partition2.Make(Int))
let () = print_newline ()

module Test3 = Test (Partition3.Make(Int))
diff --git a/src/union_find/README.md b/src/union_find/README.md
new file mode 100644
index 000000000..16c7b5bf9
--- /dev/null
+++ b/src/union_find/README.md
@@ -0,0 +1,39 @@
# Some implementations in OCaml of the Union/Find algorithm

All modules implementing Union/Find can be coerced by the same
signature `Partition.S`.

Note the function `alias` which is equivalent to `equiv`, but not
symmetric: `alias x y` means that `x` is an alias of `y`, which
translates in the present context as `x` not being the representative
of the equivalence class containing the equivalence between `x` and
`y`. The function `alias` is useful when managing aliases during the
static analyses of programming languages, so the representatives of
the classes are always the original object.

The module `PartitionMain` tests each with the same equivalence
relations.

## `Partition0.ml`

This is a naive, persistent implementation of Union/Find featuring an
asymptotic worst case cost of O(n^2).

## `Partition1.ml`

This is a persistent implementation of Union/Find with height-balanced
forests and without path compression, featuring an asymptotic worst
case cost of O(n*log(n)).
## `Partition2.ml`

This is an alternate version of `Partition1.ml`, using a different
data type.

## `Partition3.ml`

This is a destructive implementation of Union/Find with
height-balanced forests but without path compression, featuring an
asymptotic worst case of O(n*log(n)). In practice, though, this
implementation should be faster than the previous ones, due to a
smaller multiplicative constant term.
diff --git a/src/union_find/build.sh b/src/union_find/build.sh
new file mode 100755
index 000000000..8453429fa
--- /dev/null
+++ b/src/union_find/build.sh
@@ -0,0 +1,14 @@
#!/bin/sh
set -x
ocamlfind ocamlc -strict-sequence -w +A-48-4 -c Partition.mli
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition0.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition2.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition1.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition3.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition1.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition3.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition0.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c Partition2.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c PartitionMain.ml
ocamlfind ocamlopt -strict-sequence -w +A-48-4 -c PartitionMain.ml
ocamlfind ocamlopt -o PartitionMain.opt Partition0.cmx Partition1.cmx Partition2.cmx Partition3.cmx PartitionMain.cmx
diff --git a/src/union_find/clean.sh b/src/union_find/clean.sh
new file mode 100755
index 000000000..75ded7c50
--- /dev/null
+++ b/src/union_find/clean.sh
@@ -0,0 +1,3 @@
#!/bin/sh

\rm -f *.cmi *.cmo *.cmx *.o *.byte *.opt
diff --git a/src/union_find/dune b/src/union_find/dune
new file mode 100644
index 000000000..a4c27e725
--- /dev/null
+++ b/src/union_find/dune
@@ -0,0 +1,16 @@
(library
  (name union_find)
  (public_name ligo.union_find)
  (wrapped false) ;; TODO: do we need this?
  (modules Partition0 Partition1 Partition2 Partition3 Partition Union_find)
  (modules_without_implementation Partition)
;;  (preprocess
;;    (pps simple-utils.ppx_let_generalized)
;;  )
;;  (flags (:standard -w +1..62-4-9-44-40-42-48-30@39@33 -open Simple_utils ))
  )

(test
  (modules PartitionMain)
  (libraries union_find) ;; the library declared by the stanza above
  (name PartitionMain))
diff --git a/src/union_find/union_find.ml b/src/union_find/union_find.ml
new file mode 100644
index 000000000..17850f743
--- /dev/null
+++ b/src/union_find/union_find.ml
@@ -0,0 +1,2 @@
module Partition = Partition
module Partition0 = Partition0
diff --git a/vendors/ligo-utils/simple-utils/trace.ml b/vendors/ligo-utils/simple-utils/trace.ml
index 329203a46..1ae5360dd 100644
--- a/vendors/ligo-utils/simple-utils/trace.ml
+++ b/vendors/ligo-utils/simple-utils/trace.ml
@@ -667,6 +667,7 @@ let bind_map_pair f (a, b) =
  bind_pair (f a, f b)
+
(** Wraps a call that might trigger an exception in a result. *)