add list fold
This commit is contained in:
parent
f3c80908ee
commit
b41b676eb8
@ -361,47 +361,52 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
|
||||
let expr = List.fold_left aux record' path in
|
||||
ok expr
|
||||
| E_constant (name , lst) -> (
|
||||
let (iter , map) =
|
||||
let iterator name = fun (lst : AST.annotated_expression list) -> match lst with
|
||||
| [i ; f] -> (
|
||||
let%bind f' = match f.expression with
|
||||
| E_lambda l -> (
|
||||
let%bind body' = transpile_annotated_expression l.body in
|
||||
let%bind (input , _) = AST.get_t_function f.type_annotation in
|
||||
let%bind input' = transpile_type input in
|
||||
ok ((l.binder , input') , body')
|
||||
)
|
||||
| E_variable v -> (
|
||||
let%bind elt =
|
||||
trace_option (corner_case ~loc:__LOC__ "missing var") @@
|
||||
AST.Environment.get_opt v f.environment in
|
||||
match elt.definition with
|
||||
| ED_declaration (f , _) -> (
|
||||
match f.expression with
|
||||
| E_lambda l -> (
|
||||
let%bind body' = transpile_annotated_expression l.body in
|
||||
let%bind (input , _) = AST.get_t_function f.type_annotation in
|
||||
let%bind input' = transpile_type input in
|
||||
ok ((l.binder , input') , body')
|
||||
)
|
||||
| _ -> fail @@ unsupported_iterator f.location
|
||||
)
|
||||
| _ -> fail @@ unsupported_iterator f.location
|
||||
)
|
||||
| _ -> fail @@ unsupported_iterator f.location
|
||||
in
|
||||
let%bind i' = transpile_annotated_expression i in
|
||||
return @@ E_iterator (name , f' , i')
|
||||
)
|
||||
| _ -> fail @@ corner_case ~loc:__LOC__ "bad iterator arity"
|
||||
let iterator_generator iterator_name =
|
||||
let lambda_to_iterator_body (f : AST.annotated_expression) (l : AST.lambda) =
|
||||
let%bind body' = transpile_annotated_expression l.body in
|
||||
let%bind (input , _) = AST.get_t_function f.type_annotation in
|
||||
let%bind input' = transpile_type input in
|
||||
ok ((l.binder , input') , body')
|
||||
in
|
||||
iterator "ITER" , iterator "MAP" in
|
||||
let expression_to_iterator_body (f : AST.annotated_expression) =
|
||||
match f.expression with
|
||||
| E_lambda l -> lambda_to_iterator_body f l
|
||||
| E_variable v -> (
|
||||
let%bind elt =
|
||||
trace_option (corner_case ~loc:__LOC__ "missing var") @@
|
||||
AST.Environment.get_opt v f.environment in
|
||||
match elt.definition with
|
||||
| ED_declaration (f , _) -> (
|
||||
match f.expression with
|
||||
| E_lambda l -> lambda_to_iterator_body f l
|
||||
| _ -> fail @@ unsupported_iterator f.location
|
||||
)
|
||||
| _ -> fail @@ unsupported_iterator f.location
|
||||
)
|
||||
| _ -> fail @@ unsupported_iterator f.location
|
||||
in
|
||||
fun (lst : AST.annotated_expression list) -> match (lst , iterator_name) with
|
||||
| [i ; f] , "ITER" | [i ; f] , "MAP" -> (
|
||||
let%bind f' = expression_to_iterator_body f in
|
||||
let%bind i' = transpile_annotated_expression i in
|
||||
return @@ E_iterator (iterator_name , f' , i')
|
||||
)
|
||||
| [ collection ; initial ; f ] , "FOLD" -> (
|
||||
let%bind f' = expression_to_iterator_body f in
|
||||
let%bind initial' = transpile_annotated_expression initial in
|
||||
let%bind collection' = transpile_annotated_expression collection in
|
||||
return @@ E_fold (f' , collection' , initial')
|
||||
)
|
||||
| _ -> fail @@ corner_case ~loc:__LOC__ ("bad iterator arity:" ^ iterator_name)
|
||||
in
|
||||
let (iter , map , fold) = iterator_generator "ITER" , iterator_generator "MAP" , iterator_generator "FOLD" in
|
||||
match (name , lst) with
|
||||
| ("SET_ITER" , lst) -> iter lst
|
||||
| ("LIST_ITER" , lst) -> iter lst
|
||||
| ("MAP_ITER" , lst) -> iter lst
|
||||
| ("LIST_MAP" , lst) -> map lst
|
||||
| ("MAP_MAP" , lst) -> map lst
|
||||
| ("LIST_FOLD" , lst) -> fold lst
|
||||
| _ -> (
|
||||
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in
|
||||
return @@ E_constant (name , lst')
|
||||
|
@ -339,6 +339,20 @@ and translate_expression (expr:expression) (env:environment) : michelson result
|
||||
fail error
|
||||
)
|
||||
)
|
||||
| E_fold ((v , body) , collection , initial) -> (
|
||||
let%bind collection' = translate_expression collection env in
|
||||
let%bind initial' = translate_expression initial env in
|
||||
let%bind body' = translate_expression body (Environment.add v env) in
|
||||
let code = seq [
|
||||
collection' ;
|
||||
dip initial' ;
|
||||
i_iter (seq [
|
||||
i_swap ;
|
||||
i_pair ; body' ; dip i_drop ;
|
||||
]) ;
|
||||
] in
|
||||
ok code
|
||||
)
|
||||
| E_assignment (name , lrs , expr) -> (
|
||||
let%bind expr' = translate_expression expr env in
|
||||
let%bind get_code = Compiler_environment.get env name in
|
||||
|
@ -104,7 +104,7 @@ module Typer = struct
|
||||
let eq_1 a cst = type_value_eq (a , cst)
|
||||
let eq_2 (a , b) cst = type_value_eq (a , cst) && type_value_eq (b , cst)
|
||||
|
||||
let assert_eq_1 a b = Assert.assert_true (eq_1 a b)
|
||||
let assert_eq_1 ?msg a b = Assert.assert_true ?msg (eq_1 a b)
|
||||
|
||||
let comparator : string -> typer = fun s -> typer_2 s @@ fun a b ->
|
||||
let%bind () =
|
||||
|
@ -82,6 +82,7 @@ module Simplify = struct
|
||||
("set_remove" , "SET_REMOVE") ;
|
||||
("set_iter" , "SET_ITER") ;
|
||||
("list_iter" , "LIST_ITER") ;
|
||||
("list_fold" , "LIST_FOLD") ;
|
||||
("list_map" , "LIST_MAP") ;
|
||||
("map_iter" , "MAP_ITER") ;
|
||||
("map_map" , "MAP_MAP") ;
|
||||
@ -152,6 +153,8 @@ module Simplify = struct
|
||||
("Map.update" , "MAP_UPDATE") ;
|
||||
("Map.add" , "MAP_ADD") ;
|
||||
("Map.remove" , "MAP_REMOVE") ;
|
||||
("Map.iter" , "MAP_ITER") ;
|
||||
("Map.map" , "MAP_MAP") ;
|
||||
|
||||
("String.length", "SIZE") ;
|
||||
("String.size", "SIZE") ;
|
||||
@ -161,7 +164,9 @@ module Simplify = struct
|
||||
|
||||
("List.length", "SIZE") ;
|
||||
("List.size", "SIZE") ;
|
||||
("List.iter", "ITER") ;
|
||||
("List.iter", "LIST_ITER") ;
|
||||
("List.map" , "LIST_MAP") ;
|
||||
("List.fold" , "LIST_FOLD") ;
|
||||
|
||||
("Operation.transaction" , "CALL") ;
|
||||
("Operation.get_contract" , "CONTRACT") ;
|
||||
@ -483,7 +488,21 @@ module Typer = struct
|
||||
let%bind key = get_t_list lst in
|
||||
if eq_1 key arg
|
||||
then ok (t_list res ())
|
||||
else simple_fail "bad list iter"
|
||||
else simple_fail "bad list map"
|
||||
|
||||
let list_fold = typer_3 "LIST_FOLD" @@ fun lst init body ->
|
||||
let%bind (arg , res) = get_t_function body in
|
||||
let%bind (prec , cur) = get_t_pair arg in
|
||||
let%bind key = get_t_list lst in
|
||||
let msg = Format.asprintf "%a vs %a"
|
||||
Ast_typed.PP.type_value key
|
||||
Ast_typed.PP.type_value arg
|
||||
in
|
||||
trace (simple_error ("bad list fold:" ^ msg)) @@
|
||||
let%bind () = assert_eq_1 ~msg:"key cur" key cur in
|
||||
let%bind () = assert_eq_1 ~msg:"prec res" prec res in
|
||||
let%bind () = assert_eq_1 ~msg:"res init" res init in
|
||||
ok res
|
||||
|
||||
let not_ = typer_1 "NOT" @@ fun elt ->
|
||||
if eq_1 elt (t_bool ())
|
||||
@ -570,6 +589,7 @@ module Typer = struct
|
||||
set_iter ;
|
||||
list_iter ;
|
||||
list_map ;
|
||||
list_fold ;
|
||||
int ;
|
||||
size ;
|
||||
failwith_ ;
|
||||
|
@ -90,6 +90,8 @@ and expression' ppf (e:expression') = match e with
|
||||
fprintf ppf "let %s = %a in ( %a )" name expression expr expression body
|
||||
| E_iterator (s , ((name , _) , body) , expr) ->
|
||||
fprintf ppf "for_%s %s of %a do ( %a )" s name expression expr expression body
|
||||
| E_fold (((name , _) , body) , collection , initial) ->
|
||||
fprintf ppf "fold %a on %a with %s do ( %a )" expression collection expression initial name expression body
|
||||
| E_assignment (r , path , e) ->
|
||||
fprintf ppf "%s.%a := %a" r (list_sep lr (const ".")) path expression e
|
||||
| E_while (e , b) ->
|
||||
|
@ -69,6 +69,7 @@ and expression' =
|
||||
| E_make_empty_set of type_value
|
||||
| E_make_none of type_value
|
||||
| E_iterator of (string * ((var_name * type_value) * expression) * expression)
|
||||
| E_fold of (((var_name * type_value) * expression) * expression * expression)
|
||||
| E_if_bool of expression * expression * expression
|
||||
| E_if_none of expression * expression * ((var_name * type_value) * expression)
|
||||
| E_if_cons of (expression * expression * (((var_name * type_value) * (var_name * type_value)) * expression))
|
||||
|
@ -12,3 +12,7 @@ let%entry main (p : param) storage =
|
||||
[] -> storage
|
||||
| hd::tl -> storage.(0) + hd, tl
|
||||
in (([] : operation list), storage)
|
||||
|
||||
let fold_op (s : int list) : int =
|
||||
let aggregate = fun (prec : int) (cur : int) -> prec + cur in
|
||||
List.fold s 10 aggregate
|
||||
|
@ -674,19 +674,21 @@ let match_matej () : unit result =
|
||||
|
||||
let mligo_list () : unit result =
|
||||
let%bind program = mtype_file "./contracts/list.mligo" in
|
||||
let%bind () =
|
||||
let make_input n =
|
||||
e_pair (e_list [e_int n; e_int (2*n)])
|
||||
(e_pair (e_int 3) (e_list [e_int 8])) in
|
||||
let make_expected n =
|
||||
e_pair (e_typed_list [] t_operation)
|
||||
(e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
|
||||
in
|
||||
expect_eq_n program "main" make_input make_expected
|
||||
in
|
||||
let%bind () = expect_eq_evaluate program "x" (e_list []) in
|
||||
let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
|
||||
let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in
|
||||
let aux lst = e_list @@ List.map e_int lst in
|
||||
let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in
|
||||
(* let%bind () =
|
||||
* let make_input n =
|
||||
* e_pair (e_list [e_int n; e_int (2*n)])
|
||||
* (e_pair (e_int 3) (e_list [e_int 8])) in
|
||||
* let make_expected n =
|
||||
* e_pair (e_typed_list [] t_operation)
|
||||
* (e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
|
||||
* in
|
||||
* expect_eq_n program "main" make_input make_expected
|
||||
* in
|
||||
* let%bind () = expect_eq_evaluate program "x" (e_list []) in
|
||||
* let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
|
||||
* let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in *)
|
||||
ok ()
|
||||
|
||||
let lambda_mligo () : unit result =
|
||||
|
Loading…
Reference in New Issue
Block a user