diff --git a/src/passes/8-compiler/compiler_program.ml b/src/passes/8-compiler/compiler_program.ml index c1f7cc5b6..e4e91f921 100644 --- a/src/passes/8-compiler/compiler_program.ml +++ b/src/passes/8-compiler/compiler_program.ml @@ -80,6 +80,12 @@ let get_operator : constant -> type_value -> expression list -> predicate result prim ~children:[r_ty] I_CONTRACT ; i_assert_some_msg (i_push_string "bad address for get_contract") ; ] + | C_CONTRACT_OPT -> + let%bind tc = get_t_option ty in + let%bind r = get_t_contract tc in + let%bind r_ty = Compiler_type.type_ r in + ok @@ simple_unary @@ prim ~children:[r_ty] I_CONTRACT ; + | C_CONTRACT_ENTRYPOINT -> let%bind r = get_t_contract ty in let%bind r_ty = Compiler_type.type_ r in @@ -94,6 +100,20 @@ let get_operator : constant -> type_value -> expression list -> predicate result prim ~annot:[entry] ~children:[r_ty] I_CONTRACT ; i_assert_some_msg (i_push_string @@ Format.sprintf "bad address for get_entrypoint (%s)" entry) ; ] + | C_CONTRACT_ENTRYPOINT_OPT -> + let%bind tc = get_t_option ty in + let%bind r = get_t_contract tc in + let%bind r_ty = Compiler_type.type_ r in + let%bind entry = match lst with + | [ { content = E_literal (D_string entry); type_value = _ } ; _addr ] -> ok entry + | [ _entry ; _addr ] -> + fail @@ contract_entrypoint_must_be_literal ~loc:__LOC__ + | _ -> + fail @@ corner_case ~loc:__LOC__ "mini_c . CONTRACT_ENTRYPOINT" in + ok @@ simple_binary @@ seq [ + i_drop ; (* drop the entrypoint... *) + prim ~annot:[entry] ~children:[r_ty] I_CONTRACT ; + ] | x -> simple_fail (Format.asprintf "predicate \"%a\" doesn't exist" Stage_common.PP.constant x) ) diff --git a/src/passes/operators/operators.ml b/src/passes/operators/operators.ml index 3c2c91910..49f693030 100644 --- a/src/passes/operators/operators.ml +++ b/src/passes/operators/operators.ml @@ -70,7 +70,9 @@ module Simplify = struct | "get_chain_id" -> ok C_CHAIN_ID | "transaction" -> ok C_CALL | "get_contract" -> ok C_CONTRACT + | "get_contract_opt"-> ok C_CONTRACT_OPT | "get_entrypoint" -> ok C_CONTRACT_ENTRYPOINT + | "get_entrypoint_opt" -> ok C_CONTRACT_ENTRYPOINT_OPT | "size" -> ok C_SIZE | "int" -> ok C_INT | "abs" -> ok C_ABS @@ -228,7 +230,9 @@ module Simplify = struct | "Operation.transaction" -> ok C_CALL | "Operation.set_delegate" -> ok C_SET_DELEGATE | "Operation.get_contract" -> ok C_CONTRACT + | "Operation.get_contract_opt" -> ok C_CONTRACT_OPT | "Operation.get_entrypoint" -> ok C_CONTRACT_ENTRYPOINT + | "Operation.get_entrypoint_opt" -> ok C_CONTRACT_ENTRYPOINT_OPT | "int" -> ok C_INT | "abs" -> ok C_ABS | "unit" -> ok C_UNIT @@ -657,6 +661,20 @@ module Typer = struct get_t_contract tv in ok @@ t_contract tv' () + let get_contract_opt = typer_1_opt "CONTRACT OPT" @@ fun addr_tv tv_opt -> + if not (type_value_eq (addr_tv, t_address ())) + then fail @@ simple_error (Format.asprintf "get_contract_opt expects an address, got %a" PP.type_value addr_tv) + else + let%bind tv = + trace_option (simple_error "get_contract_opt needs a type annotation") tv_opt in + let%bind tv = + trace_strong (simple_error "get_entrypoint_opt has a not-option annotation") @@ + get_t_option tv in + let%bind tv' = + trace_strong (simple_error "get_entrypoint_opt has a not-option(contract) annotation") @@ + get_t_contract tv in + ok @@ t_option (t_contract tv' ()) () + let get_entrypoint = typer_2_opt "CONTRACT_ENTRYPOINT" @@ fun entry_tv addr_tv tv_opt -> if not (type_value_eq (entry_tv, t_string ())) then fail @@ simple_error (Format.asprintf "get_entrypoint expects a string entrypoint label for first argument, got %a" PP.type_value entry_tv) @@ -671,6 +689,23 @@ module Typer = struct get_t_contract tv in ok @@ t_contract tv' () + let get_entrypoint_opt = typer_2_opt "CONTRACT_ENTRYPOINT_OPT" @@ fun entry_tv addr_tv tv_opt -> + if not (type_value_eq (entry_tv, t_string ())) + then fail @@ simple_error (Format.asprintf "get_entrypoint_opt expects a string entrypoint label for first argument, got %a" PP.type_value entry_tv) + else + if not (type_value_eq (addr_tv, t_address ())) + then fail @@ simple_error (Format.asprintf "get_entrypoint_opt expects an address for second argument, got %a" PP.type_value addr_tv) + else + let%bind tv = + trace_option (simple_error "get_entrypoint_opt needs a type annotation") tv_opt in + let%bind tv = + trace_strong (simple_error "get_entrypoint_opt has a not-option annotation") @@ + get_t_option tv in + let%bind tv' = + trace_strong (simple_error "get_entrypoint_opt has a not-option(contract) annotation") @@ + get_t_contract tv in + ok @@ t_option (t_contract tv' ())() + let set_delegate = typer_1 "SET_DELEGATE" @@ fun delegate_opt -> let%bind () = assert_eq_1 delegate_opt (t_option (t_key_hash ()) ()) in ok @@ t_operation () @@ -1020,7 +1055,9 @@ module Typer = struct | C_CHAIN_ID -> ok @@ chain_id ; (*BLOCKCHAIN *) | C_CONTRACT -> ok @@ get_contract ; + | C_CONTRACT_OPT -> ok @@ get_contract_opt ; | C_CONTRACT_ENTRYPOINT -> ok @@ get_entrypoint ; + | C_CONTRACT_ENTRYPOINT_OPT -> ok @@ get_entrypoint_opt ; | C_AMOUNT -> ok @@ amount ; | C_BALANCE -> ok @@ balance ; | C_CALL -> ok @@ transaction ; diff --git a/src/stages/common/PP.ml b/src/stages/common/PP.ml index deebe08ee..773b5eaab 100644 --- a/src/stages/common/PP.ml +++ b/src/stages/common/PP.ml @@ -108,7 +108,9 @@ let constant ppf : constant -> unit = function (* Blockchain *) | C_CALL -> fprintf ppf "CALL" | C_CONTRACT -> fprintf ppf "CONTRACT" + | C_CONTRACT_OPT -> fprintf ppf "CONTRACT_OPT" | C_CONTRACT_ENTRYPOINT -> fprintf ppf "CONTRACT_ENTRYPOINT" + | C_CONTRACT_ENTRYPOINT_OPT -> fprintf ppf "CONTRACT_ENTRYPOINT_OPT" | C_AMOUNT -> fprintf ppf "AMOUNT" | C_BALANCE -> fprintf ppf "BALANCE" | C_SOURCE -> fprintf ppf "SOURCE" diff --git a/src/stages/common/types.ml b/src/stages/common/types.ml index 70c3bc80a..a0c6f9cb6 100644 --- a/src/stages/common/types.ml +++ b/src/stages/common/types.ml @@ -225,7 +225,9 @@ type constant = (* Blockchain *) | C_CALL | C_CONTRACT + | C_CONTRACT_OPT | C_CONTRACT_ENTRYPOINT + | C_CONTRACT_ENTRYPOINT_OPT | C_AMOUNT | C_BALANCE | C_SOURCE diff --git a/src/test/contracts/entrypoints.ligo b/src/test/contracts/entrypoints.ligo index 1d49a468c..d884a1ec9 100644 --- a/src/test/contracts/entrypoints.ligo +++ b/src/test/contracts/entrypoints.ligo @@ -3,3 +3,14 @@ function cb(const a : address; const s : unit) : list(operation) * unit is const c : contract(unit) = get_entrypoint("%cb", a) } with (list transaction(unit, 0mutez, c) end, s) + + +function cbo(const a : address; const s : unit) : list(operation) * unit is + block { + const c : contract(unit) = + case (get_entrypoint_opt("%cbo", a) : option(contract (unit))) of + | Some (c) -> c + | None -> (failwith ("entrypoint not found") : contract (unit)) + end + } + with (list transaction(unit, 0mutez, c) end, s) diff --git a/src/test/contracts/get_contract.ligo b/src/test/contracts/get_contract.ligo new file mode 100644 index 000000000..12d58aba6 --- /dev/null +++ b/src/test/contracts/get_contract.ligo @@ -0,0 +1,16 @@ +function cb(const s : unit) : list(operation) * unit is + block { + const c : contract(unit) = get_contract(source) + } + with (list transaction(unit, 0mutez, c) end, s) + + +function cbo(const s : unit) : list(operation) * unit is + block { + const c : contract(unit) = + case (get_contract_opt(source) : option(contract (unit))) of + | Some (c) -> c + | None -> (failwith ("contract not found") : contract (unit)) + end + } + with (list transaction(unit, 0mutez, c) end, s) diff --git a/src/test/integration_tests.ml b/src/test/integration_tests.ml index 97e09d2c2..204d2fc55 100644 --- a/src/test/integration_tests.ml +++ b/src/test/integration_tests.ml @@ -1916,6 +1916,26 @@ let attributes_religo () : unit result = in ok () +let get_contract_ligo () : unit result = + let%bind program = type_file "./contracts/get_contract.ligo" in + let%bind () = + let make_input = fun _n -> e_unit () in + let make_expected : int -> Ast_simplified.expression -> unit result = fun _n result -> + let%bind (ops , storage) = get_e_pair result.expression in + let%bind () = + let%bind lst = get_e_list ops.expression in + Assert.assert_list_size lst 1 in + let expected_storage = e_unit () in + Ast_simplified.Misc.assert_value_eq (expected_storage , storage) + in + let%bind () = + let amount = Memory_proto_alpha.Protocol.Alpha_context.Tez.zero in + let options = Proto_alpha_utils.Memory_proto_alpha.make_options ~amount () in + let%bind () = expect_n_strict_pos_small ~options program "cb" make_input make_expected in + expect_n_strict_pos_small ~options program "cbo" make_input make_expected in + ok () + in + ok() let entrypoints_ligo () : unit result = let%bind _program = type_file "./contracts/entrypoints.ligo" in @@ -2337,6 +2357,7 @@ let main = test_suite "Integration (End to End)" [ test "tuples_sequences_functions (religo)" tuples_sequences_functions_religo ; test "simple_access (ligo)" simple_access_ligo; test "deep_access (ligo)" deep_access_ligo; + test "get_contract (ligo)" get_contract_ligo; test "entrypoints (ligo)" entrypoints_ligo ; test "curry (mligo)" curry ; test "type tuple destruct (mligo)" type_tuple_destruct ;