add map fold
This commit is contained in:
parent
b41b676eb8
commit
c4752c5935
@ -407,6 +407,7 @@ and transpile_annotated_expression (ae:AST.annotated_expression) : expression re
|
|||||||
| ("LIST_MAP" , lst) -> map lst
|
| ("LIST_MAP" , lst) -> map lst
|
||||||
| ("MAP_MAP" , lst) -> map lst
|
| ("MAP_MAP" , lst) -> map lst
|
||||||
| ("LIST_FOLD" , lst) -> fold lst
|
| ("LIST_FOLD" , lst) -> fold lst
|
||||||
|
| ("MAP_FOLD" , lst) -> fold lst
|
||||||
| _ -> (
|
| _ -> (
|
||||||
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in
|
let%bind lst' = bind_map_list (transpile_annotated_expression) lst in
|
||||||
return @@ E_constant (name , lst')
|
return @@ E_constant (name , lst')
|
||||||
|
@ -86,6 +86,7 @@ module Simplify = struct
|
|||||||
("list_map" , "LIST_MAP") ;
|
("list_map" , "LIST_MAP") ;
|
||||||
("map_iter" , "MAP_ITER") ;
|
("map_iter" , "MAP_ITER") ;
|
||||||
("map_map" , "MAP_MAP") ;
|
("map_map" , "MAP_MAP") ;
|
||||||
|
("map_fold" , "MAP_FOLD") ;
|
||||||
("sha_256" , "SHA256") ;
|
("sha_256" , "SHA256") ;
|
||||||
("sha_512" , "SHA512") ;
|
("sha_512" , "SHA512") ;
|
||||||
("blake2b" , "BLAKE2b") ;
|
("blake2b" , "BLAKE2b") ;
|
||||||
@ -155,6 +156,7 @@ module Simplify = struct
|
|||||||
("Map.remove" , "MAP_REMOVE") ;
|
("Map.remove" , "MAP_REMOVE") ;
|
||||||
("Map.iter" , "MAP_ITER") ;
|
("Map.iter" , "MAP_ITER") ;
|
||||||
("Map.map" , "MAP_MAP") ;
|
("Map.map" , "MAP_MAP") ;
|
||||||
|
("Map.fold" , "LIST_FOLD") ;
|
||||||
|
|
||||||
("String.length", "SIZE") ;
|
("String.length", "SIZE") ;
|
||||||
("String.size", "SIZE") ;
|
("String.size", "SIZE") ;
|
||||||
@ -285,16 +287,6 @@ module Typer = struct
|
|||||||
let%bind () = assert_eq_1 arg (t_pair k v ()) in
|
let%bind () = assert_eq_1 arg (t_pair k v ()) in
|
||||||
ok @@ t_map k res ()
|
ok @@ t_map k res ()
|
||||||
|
|
||||||
let map_fold : typer = typer_2 "MAP_FOLD" @@ fun f m ->
|
|
||||||
let%bind (k, v) = get_t_map m in
|
|
||||||
let%bind (arg_1 , res) = get_t_function f in
|
|
||||||
let%bind (arg_2 , res') = get_t_function res in
|
|
||||||
let%bind (arg_3 , res'') = get_t_function res' in
|
|
||||||
let%bind () = assert_eq_1 arg_1 k in
|
|
||||||
let%bind () = assert_eq_1 arg_2 v in
|
|
||||||
let%bind () = assert_eq_1 arg_3 res'' in
|
|
||||||
ok @@ res'
|
|
||||||
|
|
||||||
let size = typer_1 "SIZE" @@ fun t ->
|
let size = typer_1 "SIZE" @@ fun t ->
|
||||||
let%bind () =
|
let%bind () =
|
||||||
Assert.assert_true @@
|
Assert.assert_true @@
|
||||||
@ -504,6 +496,20 @@ module Typer = struct
|
|||||||
let%bind () = assert_eq_1 ~msg:"res init" res init in
|
let%bind () = assert_eq_1 ~msg:"res init" res init in
|
||||||
ok res
|
ok res
|
||||||
|
|
||||||
|
let map_fold = typer_3 "MAP_FOLD" @@ fun map init body ->
|
||||||
|
let%bind (arg , res) = get_t_function body in
|
||||||
|
let%bind (prec , cur) = get_t_pair arg in
|
||||||
|
let%bind (key , value) = get_t_map map in
|
||||||
|
let msg = Format.asprintf "%a vs %a"
|
||||||
|
Ast_typed.PP.type_value key
|
||||||
|
Ast_typed.PP.type_value arg
|
||||||
|
in
|
||||||
|
trace (simple_error ("bad list fold:" ^ msg)) @@
|
||||||
|
let%bind () = assert_eq_1 ~msg:"key cur" (t_pair key value ()) cur in
|
||||||
|
let%bind () = assert_eq_1 ~msg:"prec res" prec res in
|
||||||
|
let%bind () = assert_eq_1 ~msg:"res init" res init in
|
||||||
|
ok res
|
||||||
|
|
||||||
let not_ = typer_1 "NOT" @@ fun elt ->
|
let not_ = typer_1 "NOT" @@ fun elt ->
|
||||||
if eq_1 elt (t_bool ())
|
if eq_1 elt (t_bool ())
|
||||||
then ok @@ t_bool ()
|
then ok @@ t_bool ()
|
||||||
|
@ -44,3 +44,7 @@ function iter_op (const m : foobar) : int is
|
|||||||
function map_op (const m : foobar) : foobar is
|
function map_op (const m : foobar) : foobar is
|
||||||
function increment (const i : int ; const j : int) : int is block { skip } with j + 1 ;
|
function increment (const i : int ; const j : int) : int is block { skip } with j + 1 ;
|
||||||
block { skip } with map_map(m , increment) ;
|
block { skip } with map_map(m , increment) ;
|
||||||
|
|
||||||
|
function fold_op (const m : foobar) : int is
|
||||||
|
function aggregate (const i : int ; const j : (int * int)) : int is block { skip } with i + j.0 + j.1 ;
|
||||||
|
block { skip } with map_fold(m , 10 , aggregate)
|
||||||
|
@ -400,6 +400,11 @@ let map () : unit result =
|
|||||||
let expected = e_int 66 in
|
let expected = e_int 66 in
|
||||||
expect_eq program "iter_op" input expected
|
expect_eq program "iter_op" input expected
|
||||||
in
|
in
|
||||||
|
let%bind () =
|
||||||
|
let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in
|
||||||
|
let expected = e_int 76 in
|
||||||
|
expect_eq program "fold_op" input expected
|
||||||
|
in
|
||||||
let%bind () =
|
let%bind () =
|
||||||
let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in
|
let input = ez [(1 , 10) ; (2 , 20) ; (3 , 30) ] in
|
||||||
let expected = ez [(1 , 11) ; (2 , 21) ; (3 , 31) ] in
|
let expected = ez [(1 , 11) ; (2 , 21) ; (3 , 31) ] in
|
||||||
@ -676,19 +681,19 @@ let mligo_list () : unit result =
|
|||||||
let%bind program = mtype_file "./contracts/list.mligo" in
|
let%bind program = mtype_file "./contracts/list.mligo" in
|
||||||
let aux lst = e_list @@ List.map e_int lst in
|
let aux lst = e_list @@ List.map e_int lst in
|
||||||
let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in
|
let%bind () = expect_eq program "fold_op" (aux [ 1 ; 2 ; 3 ]) (e_int 16) in
|
||||||
(* let%bind () =
|
let%bind () =
|
||||||
* let make_input n =
|
let make_input n =
|
||||||
* e_pair (e_list [e_int n; e_int (2*n)])
|
e_pair (e_list [e_int n; e_int (2*n)])
|
||||||
* (e_pair (e_int 3) (e_list [e_int 8])) in
|
(e_pair (e_int 3) (e_list [e_int 8])) in
|
||||||
* let make_expected n =
|
let make_expected n =
|
||||||
* e_pair (e_typed_list [] t_operation)
|
e_pair (e_typed_list [] t_operation)
|
||||||
* (e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
|
(e_pair (e_int (n+3)) (e_list [e_int (2*n)]))
|
||||||
* in
|
in
|
||||||
* expect_eq_n program "main" make_input make_expected
|
expect_eq_n program "main" make_input make_expected
|
||||||
* in
|
in
|
||||||
* let%bind () = expect_eq_evaluate program "x" (e_list []) in
|
let%bind () = expect_eq_evaluate program "x" (e_list []) in
|
||||||
* let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
|
let%bind () = expect_eq_evaluate program "y" (e_list @@ List.map e_int [3 ; 4 ; 5]) in
|
||||||
* let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in *)
|
let%bind () = expect_eq_evaluate program "z" (e_list @@ List.map e_int [2 ; 3 ; 4 ; 5]) in
|
||||||
ok ()
|
ok ()
|
||||||
|
|
||||||
let lambda_mligo () : unit result =
|
let lambda_mligo () : unit result =
|
||||||
|
Loading…
Reference in New Issue
Block a user