diff --git a/src/passes/4-typer-new/solver.ml b/src/passes/4-typer-new/solver.ml index 7175fab54..a81f04f3c 100644 --- a/src/passes/4-typer-new/solver.ml +++ b/src/passes/4-typer-new/solver.ml @@ -346,6 +346,16 @@ module Wrap = struct P_variable unification_body])) ] @ arg' @ body' , whole_expr + (* This is pretty much a wrapper for an n-ary function. *) + let constant : O.type_value -> T.type_value list -> (constraints * T.type_variable) = + fun f args -> + let whole_expr = Core.fresh_type_variable () in + let args' = List.map type_expression_to_type_value args in + let args_tuple = O.P_constant (C_tuple , args') in + O.[ + C_equation (f , P_constant (C_arrow , [args_tuple ; P_variable whole_expr])) + ] , whole_expr + end (* begin unionfind *) @@ -727,47 +737,7 @@ let selector_break_ctor : (type_constraint_simpl, output_break_ctor) selector = | SC_Poly _ -> WasNotSelected (* TODO: ??? (beware: symmetry) *) | SC_Typeclass _ -> WasNotSelected -let propagator_break_ctor : output_break_ctor propagator = - fun selected dbs -> - let () = ignore (dbs) in (* this propagator doesn't need to use the dbs *) - let a = selected.a_k_var in - let b = selected.a_k'_var' in - (* produce constraints: *) - - (* a.tv = b.tv *) - let eq1 = C_equation (P_variable a.tv, P_variable b.tv) in - (* a.c_tag = b.c_tag *) - if a.c_tag <> b.c_tag then - failwith "type error: incompatible types, not same ctor" - else - (* a.tv_list = b.tv_list *) - if List.length a.tv_list <> List.length b.tv_list then - failwith "type error: incompatible types, not same length" - else - let eqs3 = List.map2 (fun aa bb -> C_equation (P_variable aa, P_variable bb)) a.tv_list b.tv_list in - let eqs = eq1 :: eqs3 in - (eqs , []) (* no new assignments *) - -(* TODO : with our selectors, the selection depends on the order in which the constraints are added :-( :-( :-( :-( - We need to return a lazy stream of constraints. *) - -type output_specialize1 = { poly : c_poly_simpl ; a_k_var : c_constructor_simpl } - - -let ( (function - [] -> 1 - | hd2::tl2 -> - f hd1 hd2 - compare_list f tl1 tl2) - | [] -> (function [] -> 0 | _::_ -> -1) (* This follows the behaviour of Pervasives.compare for lists of different length *) -let compare_type_variable a b = - Var.compare a b -let compare_label = function - | L_int a -> (function L_int b -> Int.compare a b | L_string _ -> -1) - | L_string a -> (function L_int _ -> 1 | L_string b -> String.compare a b) +(* TODO: move this to a more appropriate place and/or auto-generate it. *) let compare_simple_c_constant = function | C_arrow -> (function (* N/A -> 1 *) @@ -866,6 +836,83 @@ let compare_simple_c_constant = function | C_chain_id -> 0 (* N/A -> -1 *) ) + +(* Using a pretty-printer from the PP.ml module creates a dependency + loop, so the one that we need temporarily for debugging purposes + has been copied here. *) +let debug_pp_constant : _ -> constant_tag -> unit = fun ppf c_tag -> + let ct = match c_tag with + | Core.C_arrow -> "arrow" + | Core.C_option -> "option" + | Core.C_tuple -> "tuple" + | Core.C_record -> failwith "record" + | Core.C_variant -> failwith "variant" + | Core.C_map -> "map" + | Core.C_big_map -> "big_map" + | Core.C_list -> "list" + | Core.C_set -> "set" + | Core.C_unit -> "unit" + | Core.C_bool -> "bool" + | Core.C_string -> "string" + | Core.C_nat -> "nat" + | Core.C_mutez -> "mutez" + | Core.C_timestamp -> "timestamp" + | Core.C_int -> "int" + | Core.C_address -> "address" + | Core.C_bytes -> "bytes" + | Core.C_key_hash -> "key_hash" + | Core.C_key -> "key" + | Core.C_signature -> "signature" + | Core.C_operation -> "operation" + | Core.C_contract -> "contract" + | Core.C_chain_id -> "chain_id" + in + Format.fprintf ppf "%s" ct + +let debug_pp_c_constructor_simpl ppf { tv; c_tag; tv_list } = + Format.fprintf ppf "CTOR %a %a(%a)" Var.pp tv debug_pp_constant c_tag PP_helpers.(list_sep Var.pp (const " , ")) tv_list + +let propagator_break_ctor : output_break_ctor propagator = + fun selected dbs -> + let () = ignore (dbs) in (* this propagator doesn't need to use the dbs *) + let a = selected.a_k_var in + let b = selected.a_k'_var' in + (* produce constraints: *) + + (* a.tv = b.tv *) + let eq1 = C_equation (P_variable a.tv, P_variable b.tv) in + (* a.c_tag = b.c_tag *) + if (compare_simple_c_constant a.c_tag b.c_tag) <> 0 then + failwith (Format.asprintf "type error: incompatible types, not same ctor %a vs. %a (compare returns %d)" debug_pp_c_constructor_simpl a debug_pp_c_constructor_simpl b (compare_simple_c_constant a.c_tag b.c_tag)) + else + (* a.tv_list = b.tv_list *) + if List.length a.tv_list <> List.length b.tv_list then + failwith "type error: incompatible types, not same length" + else + let eqs3 = List.map2 (fun aa bb -> C_equation (P_variable aa, P_variable bb)) a.tv_list b.tv_list in + let eqs = eq1 :: eqs3 in + (eqs , []) (* no new assignments *) + +(* TODO : with our selectors, the selection depends on the order in which the constraints are added :-( :-( :-( :-( + We need to return a lazy stream of constraints. *) + +type output_specialize1 = { poly : c_poly_simpl ; a_k_var : c_constructor_simpl } + + +let ( (function + [] -> 1 + | hd2::tl2 -> + f hd1 hd2 + compare_list f tl1 tl2) + | [] -> (function [] -> 0 | _::_ -> -1) (* This follows the behaviour of Pervasives.compare for lists of different length *) +let compare_type_variable a b = + Var.compare a b +let compare_label = function + | L_int a -> (function L_int b -> Int.compare a b | L_string _ -> -1) + | L_string a -> (function L_int _ -> 1 | L_string b -> String.compare a b) let rec compare_typeclass a b = compare_list (compare_list compare_type_value) a b and compare_type_value = function | P_forall { binder=a1; constraints=a2; body=a3 } -> (function diff --git a/src/passes/4-typer-new/typer.ml b/src/passes/4-typer-new/typer.ml index 7eb46e26e..0f75c8bb6 100644 --- a/src/passes/4-typer-new/typer.ml +++ b/src/passes/4-typer-new/typer.ml @@ -889,6 +889,7 @@ and type_expression : environment -> Solver.state -> ?tv_opt:O.type_value -> I.e let e' = Environment.add_ez_binder (fst binder) fresh e in let%bind (result , state') = type_expression e' state result in + let () = Printf.printf "this does not make use of the typed body, this code sounds buggy." in let wrapped = Wrap.lambda fresh input_type' output_type' in return_wrapped (E_lambda {binder = fst binder; body=result}) (* TODO: is the type of the entire lambda enough to access the input_type=fresh; ? *) @@ -897,8 +898,17 @@ and type_expression : environment -> Solver.state -> ?tv_opt:O.type_value -> I.e | E_constant (name, lst) -> let () = ignore (name , lst) in - let _t = Operators.Typer.Operators_types.constant_type name in - Pervasives.failwith (Format.asprintf "TODO: E_constant (%a(%a))" Stage_common.PP.constant name (Format.pp_print_list Ast_simplified.PP.expression) lst) + let%bind t = Operators.Typer.Operators_types.constant_type name in + let aux acc expr = + let (lst , state) = acc in + let%bind (expr, state') = type_expression e state expr in + ok (expr::lst , state') in + let%bind (lst , state') = bind_fold_list aux ([], state) lst in + let lst_annot = List.map (fun (x : O.value) -> x.type_annotation) lst in + let wrapped = Wrap.constant t lst_annot in + return_wrapped + (E_constant (name, lst)) + state' wrapped (* let%bind lst' = bind_list @@ List.map (type_expression e) lst in let tv_lst = List.map get_type_annotation lst' in diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 49f693030..819645c85 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -324,51 +324,52 @@ module Typer = struct let tc_addargs a b c = tc [a;b;c] [ (*TODO…*) ] let t_none = forall "a" @@ fun a -> option a - let t_sub = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_subarg a b c] => a --> b --> c (* TYPECLASS *) + let t_sub = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_subarg a b c] => tuple2 a b --> c (* TYPECLASS *) let t_some = forall "a" @@ fun a -> a --> option a - let t_map_remove = forall2 "src" "dst" @@ fun src dst -> src --> map src dst --> map src dst - let t_map_add = forall2 "src" "dst" @@ fun src dst -> src --> dst --> map src dst --> map src dst - let t_map_update = forall2 "src" "dst" @@ fun src dst -> src --> option dst --> map src dst --> map src dst - let t_map_mem = forall2 "src" "dst" @@ fun src dst -> src --> map src dst --> bool - let t_map_find = forall2 "src" "dst" @@ fun src dst -> src --> map src dst --> dst - let t_map_find_opt = forall2 "src" "dst" @@ fun src dst -> src --> map src dst --> option dst - let t_map_fold = forall3 "src" "dst" "acc" @@ fun src dst acc -> ( ( (src * dst) * acc ) --> acc ) --> map src dst --> acc --> acc - let t_map_map = forall3 "k" "v" "result" @@ fun k v result -> ((k * v) --> result) --> map k v --> map k result + let t_map_remove = forall2 "src" "dst" @@ fun src dst -> tuple2 src (map src dst) --> map src dst + let t_map_add = forall2 "src" "dst" @@ fun src dst -> tuple3 src dst (map src dst) --> map src dst + let t_map_update = forall2 "src" "dst" @@ fun src dst -> tuple3 src (option dst) (map src dst) --> map src dst + let t_map_mem = forall2 "src" "dst" @@ fun src dst -> tuple2 src (map src dst) --> bool + let t_map_find = forall2 "src" "dst" @@ fun src dst -> tuple2 src (map src dst) --> dst + let t_map_find_opt = forall2 "src" "dst" @@ fun src dst -> tuple2 src (map src dst) --> option dst + let t_map_fold = forall3 "src" "dst" "acc" @@ fun src dst acc -> tuple3 ( ( (src * dst) * acc ) --> acc ) (map src dst) acc --> acc + let t_map_map = forall3 "k" "v" "result" @@ fun k v result -> tuple2 ((k * v) --> result) (map k v) --> map k result (* TODO: the type of map_map_fold might be wrong, check it. *) - let t_map_map_fold = forall4 "k" "v" "acc" "dst" @@ fun k v acc dst -> ( ((k * v) * acc) --> acc * dst ) --> map k v --> (k * v) --> (map k dst * acc) - let t_map_iter = forall2 "k" "v" @@ fun k v -> ( (k * v) --> unit ) --> map k v --> unit - let t_size = forall_tc "c" @@ fun c -> [tc_sizearg c] => c --> nat (* TYPECLASS *) - let t_slice = nat --> nat --> string --> string - let t_failwith = string --> unit - let t_get_force = forall2 "src" "dst" @@ fun src dst -> src --> map src dst --> dst - let t_int = nat --> int - let t_bytes_pack = forall_tc "a" @@ fun a -> [tc_packable a] => a --> bytes (* TYPECLASS *) - let t_bytes_unpack = forall_tc "a" @@ fun a -> [tc_packable a] => bytes --> a (* TYPECLASS *) - let t_hash256 = bytes --> bytes - let t_hash512 = bytes --> bytes - let t_blake2b = bytes --> bytes - let t_hash_key = key --> key_hash - let t_check_signature = key --> signature --> bytes --> bool - let t_sender = address - let t_source = address - let t_unit = unit - let t_amount = mutez - let t_address = address - let t_now = timestamp - let t_transaction = forall "a" @@ fun a -> a --> mutez --> contract a --> operation - let t_get_contract = forall "a" @@ fun a -> contract a - let t_abs = int --> nat - let t_cons = forall "a" @@ fun a -> a --> list a --> list a - let t_assertion = bool --> unit - let t_times = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_timargs a b c] => a --> b --> c (* TYPECLASS *) - let t_div = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_divargs a b c] => a --> b --> c (* TYPECLASS *) - let t_mod = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_modargs a b c] => a --> b --> c (* TYPECLASS *) - let t_add = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_addargs a b c] => a --> b --> c (* TYPECLASS *) - let t_set_mem = forall "a" @@ fun a -> a --> set a --> bool - let t_set_add = forall "a" @@ fun a -> a --> set a --> set a - let t_set_remove = forall "a" @@ fun a -> a --> set a --> set a - let t_not = bool --> bool + let t_map_map_fold = forall4 "k" "v" "acc" "dst" @@ fun k v acc dst -> tuple3 ( ((k * v) * acc) --> acc * dst ) (map k v) (k * v) --> (map k dst * acc) + let t_map_iter = forall2 "k" "v" @@ fun k v -> tuple2 ( (k * v) --> unit ) (map k v) --> unit + let t_size = forall_tc "c" @@ fun c -> [tc_sizearg c] => tuple1 c --> nat (* TYPECLASS *) + let t_slice = tuple3 nat nat string --> string + let t_failwith = tuple1 string --> unit + let t_get_force = forall2 "src" "dst" @@ fun src dst -> tuple2 src (map src dst) --> dst + let t_int = tuple1 nat --> int + let t_bytes_pack = forall_tc "a" @@ fun a -> [tc_packable a] => tuple1 a --> bytes (* TYPECLASS *) + let t_bytes_unpack = forall_tc "a" @@ fun a -> [tc_packable a] => tuple1 bytes --> a (* TYPECLASS *) + let t_hash256 = tuple1 bytes --> bytes + let t_hash512 = tuple1 bytes --> bytes + let t_blake2b = tuple1 bytes --> bytes + let t_hash_key = tuple1 key --> key_hash + let t_check_signature = tuple3 key signature bytes --> bool + let t_chain_id = tuple0 --> chain_id + let t_sender = tuple0 --> address + let t_source = tuple0 --> address + let t_unit = tuple0 --> unit + let t_amount = tuple0 --> mutez + let t_address = tuple0 --> address + let t_now = tuple0 --> timestamp + let t_transaction = forall "a" @@ fun a -> tuple3 a mutez (contract a) --> operation + let t_get_contract = forall "a" @@ fun a -> tuple0 --> contract a + let t_abs = tuple1 int --> nat + let t_cons = forall "a" @@ fun a -> a --> tuple1 (list a) --> list a + let t_assertion = tuple1 bool --> unit + let t_times = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_timargs a b c] => tuple2 a b --> c (* TYPECLASS *) + let t_div = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_divargs a b c] => tuple2 a b --> c (* TYPECLASS *) + let t_mod = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_modargs a b c] => tuple2 a b --> c (* TYPECLASS *) + let t_add = forall3_tc "a" "b" "c" @@ fun a b c -> [tc_addargs a b c] => tuple2 a b --> c (* TYPECLASS *) + let t_set_mem = forall "a" @@ fun a -> tuple2 a (set a) --> bool + let t_set_add = forall "a" @@ fun a -> tuple2 a (set a) --> set a + let t_set_remove = forall "a" @@ fun a -> tuple2 a (set a) --> set a + let t_not = tuple1 bool --> bool let constant_type : constant -> Typesystem.Core.type_value result = function | C_INT -> ok @@ t_int ; @@ -442,7 +443,7 @@ module Typer = struct | C_BLAKE2b -> ok @@ t_blake2b ; | C_HASH_KEY -> ok @@ t_hash_key ; | C_CHECK_SIGNATURE -> ok @@ t_check_signature ; - | C_CHAIN_ID -> ok @@ failwith "t_chain_id" ; + | C_CHAIN_ID -> ok @@ t_chain_id ; (*BLOCKCHAIN *) | C_CONTRACT -> ok @@ t_get_contract ; | C_CONTRACT_ENTRYPOINT -> ok @@ failwith "t_get_entrypoint" ; diff --git a/src/stages/typesystem/shorthands.ml b/src/stages/typesystem/shorthands.ml index 44af59ad9..109b7b15b 100644 --- a/src/stages/typesystem/shorthands.ml +++ b/src/stages/typesystem/shorthands.ml @@ -54,6 +54,7 @@ let mutez = P_constant (C_mutez , []) let timestamp = P_constant (C_timestamp , []) let int = P_constant (C_int , []) let address = P_constant (C_address , []) +let chain_id = P_constant (C_chain_id , []) let bytes = P_constant (C_bytes , []) let key = P_constant (C_key , []) let key_hash = P_constant (C_key_hash , []) @@ -61,3 +62,9 @@ let signature = P_constant (C_signature , []) let operation = P_constant (C_operation , []) let contract t = P_constant (C_contract , [t]) let ( * ) a b = pair a b + +(* These are used temporarily to de-curry functions that correspond to Michelson operators *) +let tuple0 = P_constant (C_tuple , []) +let tuple1 a = P_constant (C_tuple , [a]) +let tuple2 a b = P_constant (C_tuple , [a; b]) +let tuple3 a b c = P_constant (C_tuple , [a; b; c])