add list fold

This commit is contained in:
galfour 2019-09-23 23:33:25 +02:00
parent f3c80908ee
commit b41b676eb8
8 changed files with 98 additions and 50 deletions

View File

@ -361,16 +361,16 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
let expr = List.fold_left aux record' path in
ok expr
| E_constant (name , lst) -> (
let (iter , map) =
let iterator name = fun (lst : AST.annotated_expression list) -> match lst with
| [i ; f] -> (
let%bind f' = match f.expression with
| E_lambda l -> (
let iterator_generator iterator_name =
let lambda_to_iterator_body (f : AST.annotated_expression) (l : AST.lambda) =
let%bind body' = transpile_annotated_expression l.body in
let%bind (input , _) = AST.get_t_function f.type_annotation in
let%bind input' = transpile_type input in
ok ((l.binder , input') , body')
)
in
let expression_to_iterator_body (f : AST.annotated_expression) =
match f.expression with
| E_lambda l -> lambda_to_iterator_body f l
| E_variable v -> (
let%bind elt =
trace_option (corner_case ~loc:__LOC__ "missing var") @@
@ -378,30 +378,35 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
match elt.definition with
| ED_declaration (f , _) -> (
match f.expression with
| E_lambda l -> (
let%bind body' = transpile_annotated_expression l.body in
let%bind (input , _) = AST.get_t_function f.type_annotation in
let%bind input' = transpile_type input in
ok ((l.binder , input') , body')
)
| E_lambda l -> lambda_to_iterator_body f l
| _ -> fail @@ unsupported_iterator f.location
)
| _ -> fail @@ unsupported_iterator f.location
)
| _ -> fail @@ unsupported_iterator f.location
in
fun (lst : AST.annotated_expression list) -> match (lst , iterator_name) with
| [i ; f] , "ITER" | [i ; f] , "MAP" -> (
let%bind f' = expression_to_iterator_body f in
let%bind i' = transpile_annotated_expression i in
return @@ E_iterator (name , f' , i')
return @@ E_iterator (iterator_name , f' , i')
)
| _ -> fail @@ corner_case ~loc:__LOC__ "bad iterator arity"
| [ collection ; initial ; f ] , "FOLD" -> (
let%bind f' = expression_to_iterator_body f in
let%bind initial' = transpile_annotated_expression initial in
let%bind collection' = transpile_annotated_expression collection in
return @@ E_fold (f' , collection' , initial')
)
| _ -> fail @@ corner_case ~loc:__LOC__ ("bad iterator arity:" ^ iterator_name)
in
iterator "ITER" , iterator "MAP" in
let (iter , map , fold) = iterator_generator "ITER" , iterator_generator "MAP" , iterator_generator "FOLD" in
match (name , lst) with
| ("SET_ITER" , lst) -> iter lst
| ("LIST_ITER" , lst) -> iter lst
| ("MAP_ITER" , lst) -> iter lst
| ("LIST_MAP" , lst) -> map lst
| ("MAP_MAP" , lst) -> map lst
| ("LIST_FOLD" , lst) -> fold lst
| _ -> (
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in
return @@ E_constant (name , lst')

View File

@ -339,6 +339,20 @@ and translate_expression (expr:expression) (env:environment) : michelson result
fail error
)
)
| E_fold ((v , body) , collection , initial) -> (
let%bind collection' = translate_expression collection env in
let%bind initial' = translate_expression initial env in
let%bind body' = translate_expression body (Environment.add v env) in
let code = seq [
collection' ;
dip initial' ;
i_iter (seq [
i_swap ;
i_pair ; body' ; dip i_drop ;
]) ;
] in
ok code
)
| E_assignment (name , lrs , expr) -> (
let%bind expr' = translate_expression expr env in
let%bind get_code = Compiler_environment.get env name in

View File

@ -104,7 +104,7 @@ module Typer = struct
let eq_1 a cst = type_value_eq (a , cst)
let eq_2 (a , b) cst = type_value_eq (a , cst) && type_value_eq (b , cst)
let assert_eq_1 a b = Assert.assert_true (eq_1 a b)
let assert_eq_1 ?msg a b = Assert.assert_true ?msg (eq_1 a b)
let comparator : string -> typer = fun s -> typer_2 s @@ fun a b ->
let%bind () =

View File

@ -82,6 +82,7 @@ module Simplify = struct
("set_remove" , "SET_REMOVE") ;
("set_iter" , "SET_ITER") ;
("list_iter" , "LIST_ITER") ;
("list_fold" , "LIST_FOLD") ;
("list_map" , "LIST_MAP") ;
("map_iter" , "MAP_ITER") ;
("map_map" , "MAP_MAP") ;
@ -152,6 +153,8 @@ module Simplify = struct
("Map.update" , "MAP_UPDATE") ;
("Map.add" , "MAP_ADD") ;
("Map.remove" , "MAP_REMOVE") ;
("Map.iter" , "MAP_ITER") ;
("Map.map" , "MAP_MAP") ;
("String.length", "SIZE") ;
("String.size", "SIZE") ;
@ -161,7 +164,9 @@ module Simplify = struct
("List.length", "SIZE") ;
("List.size", "SIZE") ;
("List.iter", "ITER") ;
("List.iter", "LIST_ITER") ;
("List.map" , "LIST_MAP") ;
("List.fold" , "LIST_FOLD") ;
("Operation.transaction" , "CALL") ;
("Operation.get_contract" , "CONTRACT") ;
@ -483,7 +488,21 @@ module Typer = struct
let%bind key = get_t_list lst in
if eq_1 key arg
then ok (t_list res ())
else simple_fail "bad list iter"
else simple_fail "bad list map"
let list_fold = typer_3 "LIST_FOLD" @@ fun lst init body ->
let%bind (arg , res) = get_t_function body in
let%bind (prec , cur) = get_t_pair arg in
let%bind key = get_t_list lst in
let msg = Format.asprintf "%a vs %a"
Ast_typed.PP.type_value key
Ast_typed.PP.type_value arg
in
trace (simple_error ("bad list fold:" ^ msg)) @@
let%bind () = assert_eq_1 ~msg:"key cur" key cur in
let%bind () = assert_eq_1 ~msg:"prec res" prec res in
let%bind () = assert_eq_1 ~msg:"res init" res init in
ok res
let not_ = typer_1 "NOT" @@ fun elt ->
if eq_1 elt (t_bool ())
@ -570,6 +589,7 @@ module Typer = struct
set_iter ;
list_iter ;
list_map ;
list_fold ;
int ;
size ;
failwith_ ;

View File

@ -90,6 +90,8 @@ and expression' ppf (e:expression') = match e with
fprintf ppf "let %s = %a in ( %a )" name expression expr expression body
| E_iterator (s , ((name , _) , body) , expr) ->
fprintf ppf "for_%s %s of %a do ( %a )" s name expression expr expression body
| E_fold (((name , _) , body) , collection , initial) ->
fprintf ppf "fold %a on %a with %s do ( %a )" expression collection expression initial name expression body
| E_assignment (r , path , e) ->
fprintf ppf "%s.%a := %a" r (list_sep lr (const ".")) path expression e
| E_while (e , b) ->

View File

@ -69,6 +69,7 @@ and expression' =
| E_make_empty_set of type_value
| E_make_none of type_value
| E_iterator of (string * ((var_name * type_value) * expression) * expression)
| E_fold of (((var_name * type_value) * expression) * expression * expression)
| E_if_bool of expression * expression * expression
| E_if_none of expression * expression * ((var_name * type_value) * expression)
| E_if_cons of (expression * expression * (((var_name * type_value) * (var_name * type_value)) * expression))

View File

@ -12,3 +12,7 @@ let%entry main (p : param) storage =
[] -> storage
| hd::tl -> storage.(0) + hd, tl
in (([] : operation list), storage)
let fold_op (s : int list) : int =
let aggregate = fun (prec : int) (cur : int) -> prec + cur in
List.fold s 10 aggregate

View File

@ -674,19 +674,21 @@ let match_matej () : unit result =
let mligo_list () : unit result =
let%bind program = mtype_file "./contracts/list.mligo" in
let%bind () =
let make_input n =
e_pair (e_list [e_int n; e_int (2*n)])
(e_pair (e_int 3) (e_list [e_int 8])) in
let make_expected n =
e_pair (e_typed_list [] t_operation)
(e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
in
expect_eq_n program "main" make_input make_expected
in
let%bind () = expect_eq_evaluate program "x" (e_list []) in
let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in
let aux lst = e_list @@ List.map e_int lst in
let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in
(* let%bind () =
* let make_input n =
* e_pair (e_list [e_int n; e_int (2*n)])
* (e_pair (e_int 3) (e_list [e_int 8])) in
* let make_expected n =
* e_pair (e_typed_list [] t_operation)
* (e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
* in
* expect_eq_n program "main" make_input make_expected
* in
* let%bind () = expect_eq_evaluate program "x" (e_list []) in
* let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
* let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in *)
ok ()
let lambda_mligo () : unit result =