From 7b9d861a34a07bf3885dfef9c500c1cdc42b7dfc Mon Sep 17 00:00:00 2001 From: galfour Date: Fri, 19 Jul 2019 12:13:09 +0200 Subject: [PATCH] type new operators --- src/ast_typed/combinators.ml | 8 ++ src/compiler/compiler_program.ml | 5 + src/contracts/arithmetic.ligo | 3 + src/operators/helpers.ml | 35 ++++++ src/operators/operators.ml | 206 +++++++++++++++++++++++++------ src/simplify/pascaligo.ml | 11 +- src/test/integration_tests.ml | 1 - 7 files changed, 222 insertions(+), 47 deletions(-) diff --git a/src/ast_typed/combinators.ml b/src/ast_typed/combinators.ml index 1b4a1926c..ec745fabc 100644 --- a/src/ast_typed/combinators.ml +++ b/src/ast_typed/combinators.ml @@ -141,6 +141,11 @@ let get_t_map (t:type_value) : (type_value * type_value) result = | T_constant ("map", [k;v]) -> ok (k, v) | _ -> simple_fail "get: not a map" +let get_t_big_map (t:type_value) : (type_value * type_value) result = + match t.type_value' with + | T_constant ("big_map", [k;v]) -> ok (k, v) + | _ -> simple_fail "get: not a big_map" + let get_t_map_key : type_value -> type_value result = fun t -> let%bind (key , _) = get_t_map t in ok key @@ -154,6 +159,7 @@ let assert_t_map = fun t -> ok () let is_t_map = Function.compose to_bool get_t_map +let is_t_big_map = Function.compose to_bool get_t_big_map let assert_t_tez : type_value -> unit result = get_t_tez let assert_t_key = get_t_key @@ -165,8 +171,10 @@ let assert_t_list t = ok () let is_t_list = Function.compose to_bool get_t_list +let is_t_set = Function.compose to_bool get_t_set let is_t_nat = Function.compose to_bool get_t_nat let is_t_string = Function.compose to_bool get_t_string +let is_t_bytes = Function.compose to_bool get_t_bytes let is_t_int = Function.compose to_bool get_t_int let assert_t_bytes = fun t -> diff --git a/src/compiler/compiler_program.ml b/src/compiler/compiler_program.ml index 813def75c..f7f0f50a3 100644 --- a/src/compiler/compiler_program.ml +++ b/src/compiler/compiler_program.ml @@ -21,6 +21,11 @@ let get_predicate : string -> type_value -> expression list -> predicate result let%bind m_ty = Compiler_type.type_ ty' in ok @@ simple_unary @@ prim ~children:[m_ty] I_NONE ) + | "NIL" -> ( + let%bind ty' = Mini_c.get_t_list ty in + let%bind m_ty = Compiler_type.type_ ty' in + ok @@ simple_unary @@ prim ~children:[m_ty] I_NIL + ) | "UNPACK" -> ( let%bind ty' = Mini_c.get_t_option ty in let%bind m_ty = Compiler_type.type_ ty' in diff --git a/src/contracts/arithmetic.ligo b/src/contracts/arithmetic.ligo index 25b756b04..efaa0e62b 100644 --- a/src/contracts/arithmetic.ligo +++ b/src/contracts/arithmetic.ligo @@ -15,3 +15,6 @@ function div_op (const n : int) : int is function int_op (const n : nat) : int is block { skip } with int(n) + +function neg_op (const n : int) : int is + begin skip end with -n diff --git a/src/operators/helpers.ml b/src/operators/helpers.ml index 7982ddde0..8fd18a16f 100644 --- a/src/operators/helpers.ml +++ b/src/operators/helpers.ml @@ -70,6 +70,33 @@ module Typer = struct | _ -> fail @@ wrong_param_number s 3 lst let typer_3 name f : typer = (name , typer'_3 name f) + let typer'_4 : name -> (type_value -> type_value -> type_value -> type_value -> type_value result) -> typer' = fun s f lst _ -> + match lst with + | [ a ; b ; c ; d ] -> ( + let%bind tv' = f a b c d in + ok (s , tv') + ) + | _ -> fail @@ wrong_param_number s 4 lst + let typer_4 name f : typer = (name , typer'_4 name f) + + let typer'_5 : name -> (type_value -> type_value -> type_value -> type_value -> type_value -> type_value result) -> typer' = fun s f lst _ -> + match lst with + | [ a ; b ; c ; d ; e ] -> ( + let%bind tv' = f a b c d e in + ok (s , tv') + ) + | _ -> fail @@ wrong_param_number s 5 lst + let typer_5 name f : typer = (name , typer'_5 name f) + + let typer'_6 : name -> (type_value -> type_value -> type_value -> type_value -> type_value -> type_value -> type_value result) -> typer' = fun s f lst _ -> + match lst with + | [ a ; b ; c ; d ; e ; f_ ] -> ( + let%bind tv' = f a b c d e f_ in + ok (s , tv') + ) + | _ -> fail @@ wrong_param_number s 6 lst + let typer_6 name f : typer = (name , typer'_6 name f) + let constant name cst = typer_0 name (fun _ -> ok cst) open Combinators @@ -77,6 +104,8 @@ module Typer = struct let eq_1 a cst = type_value_eq (a , cst) let eq_2 (a , b) cst = type_value_eq (a , cst) && type_value_eq (b , cst) + let assert_eq_1 a b = Assert.assert_true (eq_1 a b) + let comparator : string -> typer = fun s -> typer_2 s @@ fun a b -> let%bind () = trace_strong (error_uncomparable_types a b) @@ @@ -114,8 +143,14 @@ module Compiler = struct | Unary of michelson | Binary of michelson | Ternary of michelson + | Tetrary of michelson + | Pentary of michelson + | Hexary of michelson let simple_constant c = Constant c let simple_unary c = Unary c let simple_binary c = Binary c let simple_ternary c = Ternary c + let simple_tetrary c = Tetrary c + let simple_pentary c = Pentary c + let simple_hexary c = Hexary c end diff --git a/src/operators/operators.ml b/src/operators/operators.ml index 70fc01986..6a85aaa37 100644 --- a/src/operators/operators.ml +++ b/src/operators/operators.ml @@ -91,7 +91,7 @@ module Simplify = struct module Ligodity = struct let constants = [ ("assert" , "ASSERT") ; - + ("Current.balance", "BALANCE") ; ("balance", "BALANCE") ; ("Current.time", "NOW") ; @@ -132,7 +132,7 @@ module Simplify = struct ("Map.update" , "MAP_UPDATE") ; ("Map.add" , "MAP_ADD") ; ("Map.remove" , "MAP_REMOVE") ; - + ("String.length", "SIZE") ; ("String.size", "SIZE") ; ("String.slice", "SLICE") ; @@ -196,6 +196,8 @@ module Typer = struct then ok @@ t_int () else if (eq_2 (a , b) (t_timestamp ())) then ok @@ t_int () else + if (eq_1 a (t_timestamp ()) && eq_1 b (t_int ())) + then ok @@ t_timestamp () else if (eq_2 (a , b) (t_tez ())) then ok @@ t_tez () else fail (simple_error "Typing substraction, bad parameters.") @@ -220,7 +222,7 @@ module Typer = struct let%bind () = assert_type_value_eq (dst, v') in ok m - let map_mem : typer = typer_2 "MAP_MEM_TODO" @@ fun k m -> + let map_mem : typer = typer_2 "MAP_MEM" @@ fun k m -> let%bind (src, _dst) = get_t_map m in let%bind () = assert_type_value_eq (src, k) in ok @@ t_bool () @@ -235,46 +237,77 @@ module Typer = struct let%bind () = assert_type_value_eq (src, k) in ok @@ t_option dst () - let map_fold : typer = typer_3 "MAP_FOLD_TODO" @@ fun f m acc -> - let%bind (src, dst) = get_t_map m in - let expected_f_type = t_function (t_tuple [(t_tuple [src ; dst] ()) ; acc] ()) acc () in - let%bind () = assert_type_value_eq (f, expected_f_type) in - ok @@ acc - - let map_map : typer = typer_2 "MAP_MAP_TODO" @@ fun f m -> + let map_iter : typer = typer_2 "MAP_ITER" @@ fun f m -> let%bind (k, v) = get_t_map m in - let%bind (input_type, result_type) = get_t_function f in - let%bind () = assert_type_value_eq (input_type, t_tuple [k ; v] ()) in - ok @@ t_map k result_type () - - let map_map_fold : typer = typer_3 "MAP_MAP_TODO" @@ fun f m acc -> - let%bind (k, v) = get_t_map m in - let%bind (input_type, result_type) = get_t_function f in - let%bind () = assert_type_value_eq (input_type, t_tuple [t_tuple [k ; v] () ; acc] ()) in - let%bind ttuple = get_t_tuple result_type in - match ttuple with - | [result_acc ; result_dst ] -> - ok @@ t_tuple [ t_map k result_dst () ; result_acc ] () - (* TODO: error message *) - | _ -> fail @@ simple_error "function passed to map should take (k * v) * acc as an argument" - - let map_iter : typer = typer_2 "MAP_MAP_TODO" @@ fun f m -> - let%bind (k, v) = get_t_map m in - let%bind () = assert_type_value_eq (f, t_function (t_tuple [k ; v] ()) (t_unit ()) ()) in + let%bind (arg_1 , res) = get_t_function f in + let%bind (arg_2 , res') = get_t_function res in + let%bind () = assert_eq_1 arg_1 k in + let%bind () = assert_eq_1 arg_2 v in + let%bind () = assert_eq_1 res' (t_unit ()) in ok @@ t_unit () + let map_map : typer = typer_2 "MAP_MAP" @@ fun f m -> + let%bind (k, v) = get_t_map m in + let%bind (arg_1 , res) = get_t_function f in + let%bind (arg_2 , res') = get_t_function res in + let%bind () = assert_eq_1 arg_1 k in + let%bind () = assert_eq_1 arg_2 v in + ok @@ res' + + let map_fold : typer = typer_2 "MAP_FOLD" @@ fun f m -> + let%bind (k, v) = get_t_map m in + let%bind (arg_1 , res) = get_t_function f in + let%bind (arg_2 , res') = get_t_function res in + let%bind (arg_3 , res'') = get_t_function res' in + let%bind () = assert_eq_1 arg_1 k in + let%bind () = assert_eq_1 arg_2 v in + let%bind () = assert_eq_1 arg_3 res'' in + ok @@ res' + + let big_map_remove : typer = typer_2 "BIG_MAP_REMOVE" @@ fun k m -> + let%bind (src , _) = get_t_big_map m in + let%bind () = assert_type_value_eq (src , k) in + ok m + + let big_map_add : typer = typer_3 "BIG_MAP_ADD" @@ fun k v m -> + let%bind (src, dst) = get_t_big_map m in + let%bind () = assert_type_value_eq (src, k) in + let%bind () = assert_type_value_eq (dst, v) in + ok m + + let big_map_update : typer = typer_3 "BIG_MAP_UPDATE" @@ fun k v m -> + let%bind (src, dst) = get_t_big_map m in + let%bind () = assert_type_value_eq (src, k) in + let%bind v' = get_t_option v in + let%bind () = assert_type_value_eq (dst, v') in + ok m + + let big_map_mem : typer = typer_2 "BIG_MAP_MEM" @@ fun k m -> + let%bind (src, _dst) = get_t_big_map m in + let%bind () = assert_type_value_eq (src, k) in + ok @@ t_bool () + + let big_map_find : typer = typer_2 "BIG_MAP_FIND" @@ fun k m -> + let%bind (src, dst) = get_t_big_map m in + let%bind () = assert_type_value_eq (src, k) in + ok @@ dst + + let size = typer_1 "SIZE" @@ fun t -> let%bind () = Assert.assert_true @@ - (is_t_map t || is_t_list t || is_t_string t) in + (is_t_map t || is_t_list t || is_t_string t || is_t_bytes t || is_t_set t || is_t_big_map t) in ok @@ t_nat () let slice = typer_3 "SLICE" @@ fun i j s -> - let%bind () = - Assert.assert_true @@ - (is_t_nat i && is_t_nat j && is_t_string s) in - ok @@ t_string () - + let%bind () = assert_eq_1 i (t_nat ()) in + let%bind () = assert_eq_1 j (t_nat ()) in + if eq_1 s (t_string ()) + then ok @@ t_string () + else if eq_1 s (t_bytes ()) + then ok @@ t_bytes () + else simple_fail "bad slice" + let failwith_ = typer_1 "FAILWITH" @@ fun t -> let%bind () = Assert.assert_true @@ @@ -319,7 +352,7 @@ module Typer = struct let%bind () = assert_t_signature s in let%bind () = assert_t_bytes b in ok @@ t_bool () - + let sender = constant "SENDER" @@ t_address () let source = constant "SOURCE" @@ t_address () @@ -328,6 +361,8 @@ module Typer = struct let amount = constant "AMOUNT" @@ t_tez () + let balance = constant "BALANCE" @@ t_tez () + let address = constant "ADDRESS" @@ t_address () let now = constant "NOW" @@ t_timestamp () @@ -338,6 +373,19 @@ module Typer = struct let%bind () = assert_type_value_eq (param , contract_param) in ok @@ t_operation () + let originate = typer_6 "ORIGINATE" @@ fun manager delegate_opt spendable delegatable init_balance code -> + let%bind () = assert_eq_1 manager (t_key_hash ()) in + let%bind () = assert_eq_1 delegate_opt (t_option (t_key_hash ()) ()) in + let%bind () = assert_eq_1 spendable (t_bool ()) in + let%bind () = assert_eq_1 delegatable (t_bool ()) in + let%bind () = assert_t_tez init_balance in + let%bind (arg , res) = get_t_function code in + let%bind (_param , storage) = get_t_pair arg in + let%bind (storage' , op_lst) = get_t_pair res in + let%bind () = assert_eq_1 storage storage' in + let%bind () = assert_eq_1 op_lst (t_list (t_operation ()) ()) in + ok @@ (t_pair (t_operation ()) (t_address ()) ()) + let get_contract = typer_1_opt "CONTRACT" @@ fun _ tv_opt -> let%bind tv = trace_option (simple_error "get_contract needs a type annotation") tv_opt in @@ -346,15 +394,23 @@ module Typer = struct get_t_contract tv in ok @@ t_contract tv' () + let set_delegate = typer_1 "SET_DELEGATE" @@ fun delegate_opt -> + let%bind () = assert_eq_1 delegate_opt (t_option (t_key_hash ()) ()) in + ok @@ t_operation () + let abs = typer_1 "ABS" @@ fun t -> let%bind () = assert_t_int t in ok @@ t_nat () + let neg = typer_1 "NEG" @@ fun t -> + let%bind () = Assert.assert_true (eq_1 t (t_nat ()) || eq_1 t (t_int ())) in + ok @@ t_int () + let assertion = typer_1 "ASSERT" @@ fun a -> if eq_1 a (t_bool ()) then ok @@ t_unit () else simple_fail "Asserting a non-bool" - + let times = typer_2 "TIMES" @@ fun a b -> if eq_2 (a , b) (t_nat ()) then ok @@ t_nat () else @@ -387,6 +443,8 @@ module Typer = struct then ok @@ t_tez () else if (eq_1 a (t_nat ()) && eq_1 b (t_int ())) || (eq_1 b (t_nat ()) && eq_1 a (t_int ())) then ok @@ t_int () else + if (eq_1 a (t_timestamp ()) && eq_1 b (t_int ())) || (eq_1 b (t_timestamp ()) && eq_1 a (t_int ())) + then ok @@ t_timestamp () else simple_fail "Adding with wrong types. Expected nat, int or tez." let set_mem = typer_2 "SET_MEM" @@ fun elt set -> @@ -407,11 +465,79 @@ module Typer = struct then ok set else simple_fail "Set_remove: elt and set don't match" + let set_iter = typer_2 "SET_ITER" @@ fun set body -> + let%bind (arg , res) = get_t_function body in + let%bind () = Assert.assert_true (eq_1 res (t_unit ())) in + let%bind key = get_t_set set in + if eq_1 key arg + then ok (t_unit ()) + else simple_fail "bad set iter" + + let list_iter = typer_2 "LIST_ITER" @@ fun lst body -> + let%bind (arg , res) = get_t_function body in + let%bind () = Assert.assert_true (eq_1 res (t_unit ())) in + let%bind key = get_t_list lst in + if eq_1 key arg + then ok (t_unit ()) + else simple_fail "bad list iter" + + let list_map = typer_2 "LIST_MAP" @@ fun lst body -> + let%bind (arg , res) = get_t_function body in + let%bind key = get_t_list lst in + if eq_1 key arg + then ok res + else simple_fail "bad list iter" + let not_ = typer_1 "NOT" @@ fun elt -> if eq_1 elt (t_bool ()) then ok @@ t_bool () + else if eq_1 elt (t_nat ()) || eq_1 elt (t_int ()) + then ok @@ t_int () else simple_fail "bad parameter to not" - + + let or_ = typer_2 "OR" @@ fun a b -> + if eq_2 (a , b) (t_bool ()) + then ok @@ t_bool () + else if eq_2 (a , b) (t_nat ()) + then ok @@ t_nat () + else simple_fail "bad or" + + let xor = typer_2 "XOR" @@ fun a b -> + if eq_2 (a , b) (t_bool ()) + then ok @@ t_bool () + else if eq_2 (a , b) (t_nat ()) + then ok @@ t_nat () + else simple_fail "bad xor" + + let and_ = typer_2 "AND" @@ fun a b -> + if eq_2 (a , b) (t_bool ()) + then ok @@ t_bool () + else if eq_2 (a , b) (t_nat ()) || (eq_1 b (t_nat ()) && eq_1 a (t_int ())) + then ok @@ t_nat () + else simple_fail "bad end" + + let lsl_ = typer_2 "LSL" @@ fun a b -> + if eq_2 (a , b) (t_nat ()) + then ok @@ t_nat () + else simple_fail "bad lsl" + + let lsr_ = typer_2 "LSR" @@ fun a b -> + if eq_2 (a , b) (t_nat ()) + then ok @@ t_nat () + else simple_fail "bad lsr" + + let concat = typer_2 "CONCAT" @@ fun a b -> + if eq_2 (a , b) (t_string ()) + then ok @@ t_string () + else if eq_2 (a , b) (t_bytes ()) + then ok @@ t_bytes () + else simple_fail "bad concat" + + let cons = typer_2 "CONS" @@ fun hd tl -> + let%bind elt = get_t_list tl in + let%bind () = assert_eq_1 hd elt in + ok tl + let constant_typers = Map.String.of_list [ add ; times ; @@ -428,20 +554,19 @@ module Typer = struct comparator "GE" ; boolean_operator_2 "OR" ; boolean_operator_2 "AND" ; + boolean_operator_2 "XOR" ; not_ ; map_remove ; map_add ; map_update ; map_mem ; map_find ; - map_map_fold ; map_map ; map_fold ; map_iter ; set_mem ; set_add ; set_remove ; - (* map_size ; (* use size *) *) int ; size ; failwith_ ; @@ -459,6 +584,7 @@ module Typer = struct amount ; transaction ; get_contract ; + neg ; abs ; now ; slice ; @@ -539,5 +665,5 @@ module Compiler = struct ] (* Some complex predicates will need to be added in compiler/compiler_program *) - + end diff --git a/src/simplify/pascaligo.ml b/src/simplify/pascaligo.ml index 2f3299cc3..6542473d4 100644 --- a/src/simplify/pascaligo.ml +++ b/src/simplify/pascaligo.ml @@ -468,12 +468,11 @@ let rec simpl_expression (t:Raw.expr) : expr result = return @@ e_literal ~loc (Literal_nat n) ) | EArith (Mtz n) -> ( - let (n , loc) = r_split n in - let n = Z.to_int @@ snd @@ n in - return @@ e_literal ~loc (Literal_tez n) - ) - | EArith _ as e -> - fail @@ unsupported_arith_op e + let (n , loc) = r_split n in + let n = Z.to_int @@ snd @@ n in + return @@ e_literal ~loc (Literal_tez n) + ) + | EArith (Neg e) -> simpl_unop "NEG" e | EString (String s) -> let (s , loc) = r_split s in let s' = diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index 55445db99..2f4af5adb 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -127,7 +127,6 @@ let arithmetic () : unit result = ("plus_op", fun n -> (n + 42)) ; ("minus_op", fun n -> (n - 42)) ; ("times_op", fun n -> (n * 42)) ; - (* ("div_op", fun n -> (n / 2)) ; *) ] in let%bind () = expect_eq_n_pos program "int_op" e_nat e_int in let%bind () = expect_eq_n_pos program "mod_op" e_int (fun n -> e_nat (n mod 42)) in