Started auto-generation of folds on ADTs (part of the code is generated, not all)

This commit is contained in:
Suzanne Dupéron 2019-12-11 17:54:00 +01:00
parent 8f2ff058ec
commit 3605768bb0
7 changed files with 190 additions and 0 deletions

View File

@ -0,0 +1,7 @@
Build with:
dune build adt_generator.a
Run with
python ./generator.py

View File

@ -0,0 +1,16 @@
type root =
| A of a
| B of int
| C of string
and a = {
a1 : ta1 ;
a2 : ta2 ;
}
and ta1 =
| X of root
| Y of ta2
and ta2 =
| Z of ta2

View File

@ -0,0 +1,2 @@
module A = A
module Use_a_fold = Use_a_fold

View File

@ -0,0 +1,6 @@
(library
(name adt_generator)
(public_name ligo.adt_generator)
(libraries
)
)

View File

@ -0,0 +1,70 @@
open A
type root' =
| A' of a'
| B' of int
| C' of string
and a' = {
a1' : ta1' ;
a2' : ta2' ;
}
and ta1' =
| X' of root'
| Y' of ta2'
and ta2' =
| Z' of ta2'
type 'state continue_fold = {
a : a -> 'state -> (a' * 'state) ;
ta1 : ta1 -> 'state -> (ta1' * 'state) ;
ta2 : ta2 -> 'state -> (ta2' * 'state) ;
root : root -> 'state -> (root' * 'state) ;
}
type 'state fold_config = {
root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ;
root_a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ;
root_b : int -> 'state -> ('state continue_fold) -> (int * 'state) ;
root_c : string -> 'state -> ('state continue_fold) -> (string * 'state) ;
a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ;
ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ;
ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ;
}
let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->
{
a = fold_a visitor ;
ta1 = fold_ta1 visitor ;
ta2 = fold_ta2 visitor ;
root = fold_root visitor ;
}
and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
match x with
| { a1; a2 } ->
let (a1', state) = visitor.ta1 a1 state continue_fold in
let (a2', state) = visitor.ta2 a2 state continue_fold in
({ a1'; a2' }, state)
and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
match x with
| Z v -> let (v, state) = visitor.ta2 v state continue_fold in (Z' v, state)
and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
match x with
| X v -> let (v, state) = visitor.root v state continue_fold in (X' v , state)
| Y v -> let (v, state) = visitor.ta2 v state continue_fold in (Y' v , state)
and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state ->
let continue_fold : state continue_fold = mk_continue_fold visitor in
match x with
| A v -> let (v, state) = visitor.a v state continue_fold in (A' v , state)
| B v -> let (v, state) = visitor.root_b v state continue_fold in (B' v , state)
| C v -> let (v, state) = visitor.root_c v state continue_fold in (C' v , state)
let no_op = failwith "todo"

View File

@ -0,0 +1,71 @@
adts = [
# typename, variant?, fields_or_ctors
("root", True, [
# ctor, builtin, type
("A", False, "a"),
("B", True, "int"),
("C", True, "string"),
]),
("a", False, [
("a1", False, "ta1"),
("a2", False, "ta2"),
]),
("ta1", True, [
("X", False, "root"),
("Y", False, "ta2"),
]),
("ta2", True, [
("Z", False, "ta2"),
]),
]
print "type 'state fold_config = {"
for (t, is_variant, ctors) in adts:
tt = ("%s'" % (t,)) # output type t'
print (" %s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, t, tt))
for (c, builtin, ct,) in ctors:
if builtin:
ctt = ct # TODO: use a wrapper instead of a' for the intermediate steps, and target a different type a' just to change what the output type is
else:
ctt = ("%s'" % (ct,))
print (" %s_%s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, c, ct, ctt))
print " }"
print ""
print "let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->"
print " {"
for (t, is_variant, ctors) in adts:
print (" %s = fold_%s visitor ;" % (t, t))
print " }"
print ""
for (t, is_variant, ctors) in adts:
v = t # visitor field
tt = ("%s'" % (t,)) # output type t'
print ("and fold_%s : type state . state fold_config -> %s -> state -> (%s * state) = fun visitor x state ->" % (t, t, tt,))
print " let continue_fold : state continue_fold = mk_continue_fold visitor in"
print " match x with"
if is_variant:
for (c, builtin, ct,) in ctors:
cc = ("%s'" % (c,))
print (" | %s v ->" % (c,))
print (" let (v, state) = visitor.%s_%s v state continue_fold in" % (t, c,))
if not builtin:
print (" let (v, state) = visitor.%s v state continue_fold in" % (ct,))
print (" (%s v, state)" % (cc,))
else:
print " | {"
for (f, builtin, ft,) in ctors:
print (" %s;" % (f,))
print " } ->"
for (f, builtin, ft,) in ctors:
ff = ("%s'" % (f,))
print (" let (%s, state) = visitor.%s_%s %s state continue_fold in" % (f, t, f, f,))
if not builtin:
print (" let (%s, state) = visitor.%s %s state continue_fold in" % (ff, ft, f,))
print " ({"
for (f, builtin, ft,) in ctors:
ff = ("%s'" % (f,))
print (" %s;" % (ff,))
print " }, state)"
print ""

View File

@ -0,0 +1,18 @@
open A
open Fold
let _ =
let some_root = ((failwith "assume we have some root") : root) in
let op = {
no_op with
a = fun the_a state continue_fold ->
let (a1' , state') = continue_fold.ta1 the_a.a1 state in
let (a2' , state'') = continue_fold.ta2 the_a.a2 state' in
({
a1' = a1' ;
a2' = a2' ;
}, state'')
} in
let state = () in
fold_root op some_root state