Merge branch 'feature/auto-generate-folds' into 'dev'

Automatic generation of the boilerplate for folds

See merge request ligolang/ligo!303
This commit is contained in:
Suzanne Dupéron 2020-01-10 00:24:55 +00:00
commit e7589f1a6a
7 changed files with 410 additions and 0 deletions

View File

@ -0,0 +1,7 @@
Build with:
dune build adt_generator.a
Run with
python ./generator.py

View File

@ -0,0 +1,17 @@
type root =
| A of a
| B of int
| C of string
and a = {
a1 : ta1 ;
a2 : ta2 ;
}
and ta1 =
| X of root
| Y of ta2
and ta2 =
| Z of ta2
| W of unit

View File

@ -0,0 +1,2 @@
module A = A
module Use_a_fold = Use_a_fold

View File

@ -0,0 +1,18 @@
(rule
(target fold.ml)
(deps generator.py)
(action (with-stdout-to fold.ml (run python3 ./generator.py)))
(mode (promote (until-clean))))
; (library
; (name adt_generator)
; (public_name ligo.adt_generator)
; (libraries
; )
; )
(executable
(name adt_generator)
(public_name ligo.adt_generator)
(libraries
)
)

View File

@ -0,0 +1,184 @@
open A
type root' =
| A' of a'
| B' of int
| C' of string
and a' =
{
a1' : ta1' ;
a2' : ta2' ;
}
and ta1' =
| X' of root'
| Y' of ta2'
and ta2' =
| Z' of ta2'
| W' of unit
type 'state continue_fold =
{
root : root -> 'state -> (root' * 'state) ;
root_A : a -> 'state -> (a' * 'state) ;
root_B : int -> 'state -> (int * 'state) ;
root_C : string -> 'state -> (string * 'state) ;
a : a -> 'state -> (a' * 'state) ;
a_a1 : ta1 -> 'state -> (ta1' * 'state) ;
a_a2 : ta2 -> 'state -> (ta2' * 'state) ;
ta1 : ta1 -> 'state -> (ta1' * 'state) ;
ta1_X : root -> 'state -> (root' * 'state) ;
ta1_Y : ta2 -> 'state -> (ta2' * 'state) ;
ta2 : ta2 -> 'state -> (ta2' * 'state) ;
ta2_Z : ta2 -> 'state -> (ta2' * 'state) ;
ta2_W : unit -> 'state -> (unit * 'state) ;
}
type 'state fold_config =
{
root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ;
root_pre_state : root -> 'state -> 'state ;
root_post_state : root -> root' -> 'state -> 'state ;
root_A : a -> 'state -> ('state continue_fold) -> (a' * 'state) ;
root_B : int -> 'state -> ('state continue_fold) -> (int * 'state) ;
root_C : string -> 'state -> ('state continue_fold) -> (string * 'state) ;
a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ;
a_pre_state : a -> 'state -> 'state ;
a_post_state : a -> a' -> 'state -> 'state ;
a_a1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ;
a_a2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ;
ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ;
ta1_pre_state : ta1 -> 'state -> 'state ;
ta1_post_state : ta1 -> ta1' -> 'state -> 'state ;
ta1_X : root -> 'state -> ('state continue_fold) -> (root' * 'state) ;
ta1_Y : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ;
ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ;
ta2_pre_state : ta2 -> 'state -> 'state ;
ta2_post_state : ta2 -> ta2' -> 'state -> 'state ;
ta2_Z : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ;
ta2_W : unit -> 'state -> ('state continue_fold) -> (unit * 'state) ;
}
(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *)
let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->
{
root = fold_root visitor ;
root_A = fold_root_A visitor ;
root_B = fold_root_B visitor ;
root_C = fold_root_C visitor ;
a = fold_a visitor ;
a_a1 = fold_a_a1 visitor ;
a_a2 = fold_a_a2 visitor ;
ta1 = fold_ta1 visitor ;
ta1_X = fold_ta1_X visitor ;
ta1_Y = fold_ta1_Y visitor ;
ta2 = fold_ta2 visitor ;
ta2_Z = fold_ta2_Z visitor ;
ta2_W = fold_ta2_W visitor ;
}
and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
let state = visitor.root_pre_state x state in
let (new_x, state) = visitor.root x state continue_fold in
let state = visitor.root_post_state x new_x state in
(new_x, state)
and fold_root_A : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.root_A x state continue_fold
and fold_root_B : type state . state fold_config -> int -> state -> (int * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.root_B x state continue_fold
and fold_root_C : type state . state fold_config -> string -> state -> (string * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.root_C x state continue_fold
and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
let state = visitor.a_pre_state x state in
let (new_x, state) = visitor.a x state continue_fold in
let state = visitor.a_post_state x new_x state in
(new_x, state)
and fold_a_a1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.a_a1 x state continue_fold
and fold_a_a2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.a_a2 x state continue_fold
and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
let state = visitor.ta1_pre_state x state in
let (new_x, state) = visitor.ta1 x state continue_fold in
let state = visitor.ta1_post_state x new_x state in
(new_x, state)
and fold_ta1_X : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.ta1_X x state continue_fold
and fold_ta1_Y : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.ta1_Y x state continue_fold
and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
let state = visitor.ta2_pre_state x state in
let (new_x, state) = visitor.ta2 x state continue_fold in
let state = visitor.ta2_post_state x new_x state in
(new_x, state)
and fold_ta2_Z : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.ta2_Z x state continue_fold
and fold_ta2_W : type state . state fold_config -> unit -> state -> (unit * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
visitor.ta2_W x state continue_fold
let no_op : 'a fold_config = {
root = (fun v state continue ->
match v with
| A v -> let (v, state) = continue.root_A v state in (A' v, state)
| B v -> let (v, state) = continue.root_B v state in (B' v, state)
| C v -> let (v, state) = continue.root_C v state in (C' v, state)
);
root_pre_state = (fun v state -> ignore v; state) ;
root_post_state = (fun v new_v state -> ignore (v, new_v); state) ;
root_A = (fun v state continue -> continue.a v state ) ;
root_B = (fun v state continue -> ignore continue; (v, state) ) ;
root_C = (fun v state continue -> ignore continue; (v, state) ) ;
a = (fun v state continue ->
match v with
{ a1; a2; } ->
let (a1', state) = continue.a_a1 a1 state in
let (a2', state) = continue.a_a2 a2 state in
({ a1'; a2'; }, state)
);
a_pre_state = (fun v state -> ignore v; state) ;
a_post_state = (fun v new_v state -> ignore (v, new_v); state) ;
a_a1 = (fun v state continue -> continue.ta1 v state ) ;
a_a2 = (fun v state continue -> continue.ta2 v state ) ;
ta1 = (fun v state continue ->
match v with
| X v -> let (v, state) = continue.ta1_X v state in (X' v, state)
| Y v -> let (v, state) = continue.ta1_Y v state in (Y' v, state)
);
ta1_pre_state = (fun v state -> ignore v; state) ;
ta1_post_state = (fun v new_v state -> ignore (v, new_v); state) ;
ta1_X = (fun v state continue -> continue.root v state ) ;
ta1_Y = (fun v state continue -> continue.ta2 v state ) ;
ta2 = (fun v state continue ->
match v with
| Z v -> let (v, state) = continue.ta2_Z v state in (Z' v, state)
| W v -> let (v, state) = continue.ta2_W v state in (W' v, state)
);
ta2_pre_state = (fun v state -> ignore v; state) ;
ta2_post_state = (fun v new_v state -> ignore (v, new_v); state) ;
ta2_Z = (fun v state continue -> continue.ta2 v state ) ;
ta2_W = (fun v state continue -> ignore continue; (v, state) ) ;
}

View File

@ -0,0 +1,134 @@
moduleName = "A"
adts = [
# typename, variant?, fields_or_ctors
("root", True, [
# ctor, builtin, type
("A", False, "a"),
("B", True, "int"),
("C", True, "string"),
]),
("a", False, [
("a1", False, "ta1"),
("a2", False, "ta2"),
]),
("ta1", True, [
("X", False, "root"),
("Y", False, "ta2"),
]),
("ta2", True, [
("Z", False, "ta2"),
("W", True, "unit"),
]),
]
from collections import namedtuple
adt = namedtuple('adt', ['name', 'newName', 'isVariant', 'ctorsOrFields'])
ctorOrField = namedtuple('ctorOrField', ['name', 'newName', 'isBuiltin', 'type_', 'newType'])
adts = [
adt(
name = name,
newName = f"{name}'",
isVariant = isVariant,
ctorsOrFields = [
ctorOrField(
name = cf,
newName = f"{cf}'",
isBuiltin = isBuiltin,
type_ = type_,
newType = type_ if isBuiltin else f"{type_}'",
)
for (cf, isBuiltin, type_) in ctors
],
)
for (name, isVariant, ctors) in adts
]
print("open %s" % moduleName)
print("")
for (index, t) in enumerate(adts):
typeOrAnd = "type" if index == 0 else "and"
print(f"{typeOrAnd} {t.newName} =")
if t.isVariant:
for c in t.ctorsOrFields:
print(f" | {c.newName} of {c.newType}")
else:
print(" {")
for f in t.ctorsOrFields:
print(f" {f.newName} : {f.newType} ;")
print(" }")
print("")
print(f"type 'state continue_fold =")
print(" {")
for t in adts:
print(f" {t.name} : {t.name} -> 'state -> ({t.newName} * 'state) ;")
for c in t.ctorsOrFields:
print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ({c.newType} * 'state) ;")
print(" }")
print("")
print(f"type 'state fold_config =")
print(" {")
for t in adts:
print(f" {t.name} : {t.name} -> 'state -> ('state continue_fold) -> ({t.newName} * 'state) ;")
print(f" {t.name}_pre_state : {t.name} -> 'state -> 'state ;")
print(f" {t.name}_post_state : {t.name} -> {t.newName} -> 'state -> 'state ;")
for c in t.ctorsOrFields:
print(f" {t.name}_{c.name} : {c.type_} -> 'state -> ('state continue_fold) -> ({c.newType} * 'state) ;")
print(" }")
print("")
print('(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *)')
print("let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->")
print(" {")
for t in adts:
print(f" {t.name} = fold_{t.name} visitor ;")
for c in t.ctorsOrFields:
print(f" {t.name}_{c.name} = fold_{t.name}_{c.name} visitor ;")
print("}")
print("")
for t in adts:
print(f"and fold_{t.name} : type state . state fold_config -> {t.name} -> state -> ({t.newName} * state) = fun visitor x state ->")
print(" let continue_fold : state continue_fold = mk_continue_fold visitor in")
print(f" let state = visitor.{t.name}_pre_state x state in")
print(f" let (new_x, state) = visitor.{t.name} x state continue_fold in")
print(f" let state = visitor.{t.name}_post_state x new_x state in")
print(" (new_x, state)")
print("")
for c in t.ctorsOrFields:
print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->")
print(" let continue_fold : state continue_fold = mk_continue_fold visitor in")
print(f" visitor.{t.name}_{c.name} x state continue_fold")
print("")
print("let no_op : 'a fold_config = {")
for t in adts:
print(f" {t.name} = (fun v state continue ->")
print(" match v with")
if t.isVariant:
for c in t.ctorsOrFields:
print(f" | {c.name} v -> let (v, state) = continue.{t.name}_{c.name} v state in ({c.newName} v, state)")
else:
print(" {", end=' ')
for f in t.ctorsOrFields:
print(f"{f.name};", end=' ')
print("} ->")
for f in t.ctorsOrFields:
print(f" let ({f.newName}, state) = continue.{t.name}_{f.name} {f.name} state in")
print(" ({", end=' ')
for f in t.ctorsOrFields:
print(f"{f.newName};", end=' ')
print("}, state)")
print(" );")
print(f" {t.name}_pre_state = (fun v state -> ignore v; state) ;")
print(f" {t.name}_post_state = (fun v new_v state -> ignore (v, new_v); state) ;")
for c in t.ctorsOrFields:
print(f" {t.name}_{c.name} = (fun v state continue ->", end=' ')
if c.isBuiltin:
print("ignore continue; (v, state)", end=' ')
else:
print(f"continue.{c.type_} v state", end=' ')
print(") ;")
print("}")

View File

@ -0,0 +1,48 @@
open A
open Fold
(* TODO: how should we plug these into our test framework? *)
let () =
let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in
let op = {
no_op with
a = fun the_a state continue_fold ->
let (a1' , state') = continue_fold.ta1 the_a.a1 state in
let (a2' , state'') = continue_fold.ta2 the_a.a2 state' in
({
a1' = a1' ;
a2' = a2' ;
}, state'' + 1)
} in
let state = 0 in
let (_, state) = fold_root op some_root state in
if state != 2 then
failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state)
else
()
let () =
let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in
let op = { no_op with a_pre_state = fun _the_a state -> state + 1 } in
let state = 0 in
let (_, state) = fold_root op some_root state in
if state != 2 then
failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state)
else
()
let () =
let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in
let op = { no_op with a_post_state = fun _the_a _new_a state -> state + 1 } in
let state = 0 in
let (_, state) = fold_root op some_root state in
if state != 2 then
failwith (Printf.sprintf "Test failed: expected folder to count 2 nodes, but it counted %d nodes" state)
else
()
(* Test that the same fold_config can be ascibed with different 'a type arguments *)
let _noi : int fold_config = no_op (* (fun _ -> ()) *)
let _nob : bool fold_config = no_op (* (fun _ -> ()) *)