add map fold

This commit is contained in:
galfour 2019-09-23 23:46:47 +02:00
parent b41b676eb8
commit c4752c5935
4 changed files with 39 additions and 23 deletions

View File

@ -407,6 +407,7 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
| ("LIST_MAP" , lst) -> map lst | ("LIST_MAP" , lst) -> map lst
| ("MAP_MAP" , lst) -> map lst | ("MAP_MAP" , lst) -> map lst
| ("LIST_FOLD" , lst) -> fold lst | ("LIST_FOLD" , lst) -> fold lst
| ("MAP_FOLD" , lst) -> fold lst
| _ -> ( | _ -> (
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in let%bind lst' = bind_map_list (transpile_annotated_expression) lst in
return @@ E_constant (name , lst') return @@ E_constant (name , lst')

View File

@ -86,6 +86,7 @@ module Simplify = struct
("list_map" , "LIST_MAP") ; ("list_map" , "LIST_MAP") ;
("map_iter" , "MAP_ITER") ; ("map_iter" , "MAP_ITER") ;
("map_map" , "MAP_MAP") ; ("map_map" , "MAP_MAP") ;
("map_fold" , "MAP_FOLD") ;
("sha_256" , "SHA256") ; ("sha_256" , "SHA256") ;
("sha_512" , "SHA512") ; ("sha_512" , "SHA512") ;
("blake2b" , "BLAKE2b") ; ("blake2b" , "BLAKE2b") ;
@ -155,6 +156,7 @@ module Simplify = struct
("Map.remove" , "MAP_REMOVE") ; ("Map.remove" , "MAP_REMOVE") ;
("Map.iter" , "MAP_ITER") ; ("Map.iter" , "MAP_ITER") ;
("Map.map" , "MAP_MAP") ; ("Map.map" , "MAP_MAP") ;
("Map.fold" , "LIST_FOLD") ;
("String.length", "SIZE") ; ("String.length", "SIZE") ;
("String.size", "SIZE") ; ("String.size", "SIZE") ;
@ -285,16 +287,6 @@ module Typer = struct
let%bind () = assert_eq_1 arg (t_pair k v ()) in let%bind () = assert_eq_1 arg (t_pair k v ()) in
ok @@ t_map k res () ok @@ t_map k res ()
let map_fold : typer = typer_2 "MAP_FOLD" @@ fun f m ->
let%bind (k, v) = get_t_map m in
let%bind (arg_1 , res) = get_t_function f in
let%bind (arg_2 , res') = get_t_function res in
let%bind (arg_3 , res'') = get_t_function res' in
let%bind () = assert_eq_1 arg_1 k in
let%bind () = assert_eq_1 arg_2 v in
let%bind () = assert_eq_1 arg_3 res'' in
ok @@ res'
let size = typer_1 "SIZE" @@ fun t -> let size = typer_1 "SIZE" @@ fun t ->
let%bind () = let%bind () =
Assert.assert_true @@ Assert.assert_true @@
@ -504,6 +496,20 @@ module Typer = struct
let%bind () = assert_eq_1 ~msg:"res init" res init in let%bind () = assert_eq_1 ~msg:"res init" res init in
ok res ok res
let map_fold = typer_3 "MAP_FOLD" @@ fun map init body ->
let%bind (arg , res) = get_t_function body in
let%bind (prec , cur) = get_t_pair arg in
let%bind (key , value) = get_t_map map in
let msg = Format.asprintf "%a vs %a"
Ast_typed.PP.type_value key
Ast_typed.PP.type_value arg
in
trace (simple_error ("bad list fold:" ^ msg)) @@
let%bind () = assert_eq_1 ~msg:"key cur" (t_pair key value ()) cur in
let%bind () = assert_eq_1 ~msg:"prec res" prec res in
let%bind () = assert_eq_1 ~msg:"res init" res init in
ok res
let not_ = typer_1 "NOT" @@ fun elt -> let not_ = typer_1 "NOT" @@ fun elt ->
if eq_1 elt (t_bool ()) if eq_1 elt (t_bool ())
then ok @@ t_bool () then ok @@ t_bool ()

View File

@ -44,3 +44,7 @@ function iter_op (const m : foobar) : int is
function map_op (const m : foobar) : foobar is function map_op (const m : foobar) : foobar is
function increment (const i : int ; const j : int) : int is block { skip } with j + 1 ; function increment (const i : int ; const j : int) : int is block { skip } with j + 1 ;
block { skip } with map_map(m , increment) ; block { skip } with map_map(m , increment) ;
function fold_op (const m : foobar) : int is
function aggregate (const i : int ; const j : (int * int)) : int is block { skip } with i + j.0 + j.1 ;
block { skip } with map_fold(m , 10 , aggregate)

View File

@ -400,6 +400,11 @@ let map () : unit result =
let expected = e_int 66 in let expected = e_int 66 in
expect_eq program "iter_op" input expected expect_eq program "iter_op" input expected
in in
let%bind () =
let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in
let expected = e_int 76 in
expect_eq program "fold_op" input expected
in
let%bind () = let%bind () =
let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in
let expected = ez [(1 , 11) ; (2 , 21) ; (3 , 31) ] in let expected = ez [(1 , 11) ; (2 , 21) ; (3 , 31) ] in
@ -676,19 +681,19 @@ let mligo_list () : unit result =
let%bind program = mtype_file "./contracts/list.mligo" in let%bind program = mtype_file "./contracts/list.mligo" in
let aux lst = e_list @@ List.map e_int lst in let aux lst = e_list @@ List.map e_int lst in
let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in
(* let%bind () = let%bind () =
* let make_input n = let make_input n =
* e_pair (e_list [e_int n; e_int (2*n)]) e_pair (e_list [e_int n; e_int (2*n)])
* (e_pair (e_int 3) (e_list [e_int 8])) in (e_pair (e_int 3) (e_list [e_int 8])) in
* let make_expected n = let make_expected n =
* e_pair (e_typed_list [] t_operation) e_pair (e_typed_list [] t_operation)
* (e_pair (e_int (n+3)) (e_list [e_int (2*n)])) (e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
* in in
* expect_eq_n program "main" make_input make_expected expect_eq_n program "main" make_input make_expected
* in in
* let%bind () = expect_eq_evaluate program "x" (e_list []) in let%bind () = expect_eq_evaluate program "x" (e_list []) in
* let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
* let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in *) let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in
ok () ok ()
let lambda_mligo () : unit result = let lambda_mligo () : unit result =