diff --git a/src/ligo/contracts/super-counter.ligo b/src/ligo/contracts/super-counter.ligo index c8f053d18..beb73a5a9 100644 --- a/src/ligo/contracts/super-counter.ligo +++ b/src/ligo/contracts/super-counter.ligo @@ -1,10 +1,10 @@ -type action = +type action is | Increment of int | Decrement of int function main (const p : action ; const s : int) : (list(operation) * int) is block {skip} with ((nil : operation), - match p with + case p of | Increment n -> s + n | Decrement n -> s - n end) diff --git a/src/ligo/main/contract.ml b/src/ligo/main/contract.ml index 662b828f5..ac35890d1 100644 --- a/src/ligo/main/contract.ml +++ b/src/ligo/main/contract.ml @@ -19,11 +19,11 @@ include struct open Ast_typed open Combinators - let assert_entry_point_type : type_value -> unit result = fun t -> + let get_entry_point_type : type_value -> (type_value * type_value) result = fun t -> let%bind (arg , result) = trace_strong (simple_error "entry-point doesn't have a function type") @@ get_t_function t in - let%bind (_ , storage_param) = + let%bind (arg' , storage_param) = trace_strong (simple_error "entry-point doesn't have 2 parameters") @@ get_t_pair arg in let%bind (ops , storage_result) = @@ -35,12 +35,16 @@ include struct let%bind () = trace_strong (simple_error "entry-point doesn't identitcal type (storage) for second parameter and second result") @@ assert_type_value_eq (storage_param , storage_result) in - ok () + ok (arg' , storage_param) - let assert_valid_entry_point : program -> string -> unit result = fun p e -> + let get_entry_point : program -> string -> (type_value * type_value) result = fun p e -> let%bind declaration = get_declaration_by_name p e in match declaration with - | Declaration_constant (d , _) -> assert_entry_point_type d.annotated_expression.type_annotation + | Declaration_constant (d , _) -> get_entry_point_type d.annotated_expression.type_annotation + + let assert_valid_entry_point = fun p e -> + let%bind _ = get_entry_point p e in + ok () end let transpile_value @@ -81,7 +85,7 @@ let compile_contract_file : string -> string -> string result = fun source entry ok str let compile_contract_parameter : string -> string -> string -> string result = fun source entry_point expression -> - let%bind parameter_tv = + let%bind (program , parameter_tv) = let%bind raw = trace (simple_error "parsing file") @@ Parser.parse_file source in @@ -93,11 +97,9 @@ let compile_contract_parameter : string -> string -> string -> string result = f let%bind typed = trace (simple_error "typing file") @@ Typer.type_program simplified in - let%bind () = - assert_valid_entry_point typed entry_point in - let%bind declaration = Ast_typed.Combinators.get_declaration_by_name typed entry_point in - match declaration with - | Declaration_constant (d , _) -> ok d.annotated_expression.type_annotation + let%bind (param_ty , _) = + get_entry_point typed entry_point in + ok (typed , param_ty) in let%bind expr = let%bind raw = @@ -107,8 +109,13 @@ let compile_contract_parameter : string -> string -> string -> string result = f trace (simple_error "simplifying expression") @@ Simplify.Pascaligo.simpl_expression raw in let%bind typed = + let env = + let last_declaration = Location.unwrap List.(hd @@ rev program) in + match last_declaration with + | Declaration_constant (_ , env) -> env + in trace (simple_error "typing expression") @@ - Typer.type_annotated_expression Ast_typed.Environment.full_empty simplified in + Typer.type_annotated_expression env simplified in let%bind () = trace (simple_error "expression type doesn't match type parameter") @@ Ast_typed.assert_type_value_eq (parameter_tv , typed.type_annotation) in diff --git a/src/ligo/test/integration_tests.ml b/src/ligo/test/integration_tests.ml index d29b2fc7f..a393e4510 100644 --- a/src/ligo/test/integration_tests.ml +++ b/src/ligo/test/integration_tests.ml @@ -325,6 +325,16 @@ let counter_contract () : unit result = let make_expected = fun n -> e_a_pair (e_a_list [] t_operation) (e_a_int (42 + n)) in expect_n program "main" make_input make_expected +let super_counter_contract () : unit result = + let%bind program = type_file "./contracts/super-counter.ligo" in + let make_input = fun n -> + let action = if n mod 2 = 0 then "Increment" else "Decrement" in + e_a_pair (e_a_constructor action (e_a_int n)) (e_a_int 42) in + let make_expected = fun n -> + let op = if n mod 2 = 0 then (+) else (-) in + e_a_pair (e_a_list [] t_operation) (e_a_int (op 42 n)) in + expect_n program "main" make_input make_expected + let main = "Integration (End to End)", [ test "function" function_ ; test "complex function" complex_function ; @@ -350,5 +360,6 @@ let main = "Integration (End to End)", [ test "quote declarations" quote_declarations ; test "#include directives" include_ ; test "counter contract" counter_contract ; + test "super counter contract" super_counter_contract ; test "higher order" higher_order ; ] diff --git a/src/ligo/transpiler/transpiler.ml b/src/ligo/transpiler/transpiler.ml index 008e467c5..89467701c 100644 --- a/src/ligo/transpiler/transpiler.ml +++ b/src/ligo/transpiler/transpiler.ml @@ -386,7 +386,7 @@ and translate_annotated_expression (env:Environment.t) (ae:AST.annotated_express List.find_opt (fun ((constructor_name' , _) , _) -> constructor_name' = constructor_name) lst in let env' = Environment.(add (name , tv) @@ extend env) in let%bind body' = translate_annotated_expression env' body in - return ~env:env' @@ E_let_in ((name , tv) , top , body') + return ~env @@ E_let_in ((name , tv) , top , body') ) | ((`Node (a , b)) , tv) -> let%bind a' =