From 8b98898dbf34a508d2121f7972300325fe8b220d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Suzanne=20Dup=C3=A9ron?= Date: Thu, 12 Dec 2019 12:12:34 +0100 Subject: [PATCH] first executable version of the auto-generation of folds --- src/stages/adt_generator/a.ml | 1 + src/stages/adt_generator/dune | 14 +- src/stages/adt_generator/fold.ml | 196 ++++++++++++++++------- src/stages/adt_generator/generator.py | 207 +++++++++++++++++++------ src/stages/adt_generator/use_a_fold.ml | 12 +- 5 files changed, 323 insertions(+), 107 deletions(-) diff --git a/src/stages/adt_generator/a.ml b/src/stages/adt_generator/a.ml index c30254e99..f1d8b2fb1 100644 --- a/src/stages/adt_generator/a.ml +++ b/src/stages/adt_generator/a.ml @@ -14,3 +14,4 @@ and ta1 = and ta2 = | Z of ta2 +| W of unit diff --git a/src/stages/adt_generator/dune b/src/stages/adt_generator/dune index 0e1a15f71..4a52c6088 100644 --- a/src/stages/adt_generator/dune +++ b/src/stages/adt_generator/dune @@ -1,4 +1,16 @@ -(library +(rule + (target fold.ml) + (deps generator.py) + (action (with-stdout-to fold.ml (run python3 ./generator.py))) + (mode (promote (until-clean)))) +; (library +; (name adt_generator) +; (public_name ligo.adt_generator) +; (libraries +; ) +; ) + +(executable (name adt_generator) (public_name ligo.adt_generator) (libraries diff --git a/src/stages/adt_generator/fold.ml b/src/stages/adt_generator/fold.ml index f46ca6b5e..72d64e0a4 100644 --- a/src/stages/adt_generator/fold.ml +++ b/src/stages/adt_generator/fold.ml @@ -1,70 +1,156 @@ open A type root' = -| A' of a' -| B' of int -| C' of string - -and a' = { - a1' : ta1' ; - a2' : ta2' ; -} - + | A' of a' + | B' of int + | C' of string +and a' = + { + a1' : ta1' ; + a2' : ta2' ; + } and ta1' = -| X' of root' -| Y' of ta2' - + | X' of root' + | Y' of ta2' and ta2' = -| Z' of ta2' + | Z' of ta2' + | W' of unit -type 'state continue_fold = { - a : a -> 'state -> (a' * 'state) ; - ta1 : ta1 -> 'state -> (ta1' * 'state) ; - ta2 : ta2 -> 'state -> (ta2' * 'state) ; +type 'state continue_fold = + { root : root -> 'state -> (root' * 'state) ; -} - -type 'state fold_config = { - root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; - root_a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; - root_b : int -> 'state -> ('state continue_fold) -> (int * 'state) ; - root_c : string -> 'state -> ('state continue_fold) -> (string * 'state) ; - a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; - ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; - ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + root_A : a -> 'state -> (a' * 'state) ; + root_B : int -> 'state -> (int * 'state) ; + root_C : string -> 'state -> (string * 'state) ; + a : a -> 'state -> (a' * 'state) ; + a_a1 : ta1 -> 'state -> (ta1' * 'state) ; + a_a2 : ta2 -> 'state -> (ta2' * 'state) ; + ta1 : ta1 -> 'state -> (ta1' * 'state) ; + ta1_X : root -> 'state -> (root' * 'state) ; + ta1_Y : ta2 -> 'state -> (ta2' * 'state) ; + ta2 : ta2 -> 'state -> (ta2' * 'state) ; + ta2_Z : ta2 -> 'state -> (ta2' * 'state) ; + ta2_W : unit -> 'state -> (unit * 'state) ; } +type 'state fold_config = + { + root : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + root_A : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + root_B : int -> 'state -> ('state continue_fold) -> (int * 'state) ; + root_C : string -> 'state -> ('state continue_fold) -> (string * 'state) ; + a : a -> 'state -> ('state continue_fold) -> (a' * 'state) ; + a_a1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + a_a2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta1 : ta1 -> 'state -> ('state continue_fold) -> (ta1' * 'state) ; + ta1_X : root -> 'state -> ('state continue_fold) -> (root' * 'state) ; + ta1_Y : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2 : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_Z : ta2 -> 'state -> ('state continue_fold) -> (ta2' * 'state) ; + ta2_W : unit -> 'state -> ('state continue_fold) -> (unit * 'state) ; + } + +(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *) let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor -> { - a = fold_a visitor ; - ta1 = fold_ta1 visitor ; - ta2 = fold_ta2 visitor ; root = fold_root visitor ; - } - -and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> - let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | { a1; a2 } -> - let (a1', state) = visitor.ta1 a1 state continue_fold in - let (a2', state) = visitor.ta2 a2 state continue_fold in - ({ a1'; a2' }, state) - -and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> - let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | Z v -> let (v, state) = visitor.ta2 v state continue_fold in (Z' v, state) - -and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> - let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | X v -> let (v, state) = visitor.root v state continue_fold in (X' v , state) - | Y v -> let (v, state) = visitor.ta2 v state continue_fold in (Y' v , state) + root_A = fold_root_A visitor ; + root_B = fold_root_B visitor ; + root_C = fold_root_C visitor ; + a = fold_a visitor ; + a_a1 = fold_a_a1 visitor ; + a_a2 = fold_a_a2 visitor ; + ta1 = fold_ta1 visitor ; + ta1_X = fold_ta1_X visitor ; + ta1_Y = fold_ta1_Y visitor ; + ta2 = fold_ta2 visitor ; + ta2_Z = fold_ta2_Z visitor ; + ta2_W = fold_ta2_W visitor ; +} and fold_root : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> let continue_fold : state continue_fold = mk_continue_fold visitor in - match x with - | A v -> let (v, state) = visitor.a v state continue_fold in (A' v , state) - | B v -> let (v, state) = visitor.root_b v state continue_fold in (B' v , state) - | C v -> let (v, state) = visitor.root_c v state continue_fold in (C' v , state) -let no_op = failwith "todo" + visitor.root x state continue_fold + +and fold_root_A : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_A x state continue_fold + +and fold_root_B : type state . state fold_config -> int -> state -> (int * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_B x state continue_fold + +and fold_root_C : type state . state fold_config -> string -> state -> (string * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.root_C x state continue_fold + +and fold_a : type state . state fold_config -> a -> state -> (a' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a x state continue_fold + +and fold_a_a1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a_a1 x state continue_fold + +and fold_a_a2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.a_a2 x state continue_fold + +and fold_ta1 : type state . state fold_config -> ta1 -> state -> (ta1' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1 x state continue_fold + +and fold_ta1_X : type state . state fold_config -> root -> state -> (root' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1_X x state continue_fold + +and fold_ta1_Y : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta1_Y x state continue_fold + +and fold_ta2 : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2 x state continue_fold + +and fold_ta2_Z : type state . state fold_config -> ta2 -> state -> (ta2' * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2_Z x state continue_fold + +and fold_ta2_W : type state . state fold_config -> unit -> state -> (unit * state) = fun visitor x state -> + let continue_fold : state continue_fold = mk_continue_fold visitor in + visitor.ta2_W x state continue_fold + +let no_op : 'a fold_config = { + root = (fun v state continue -> + match v with + | A v -> let (v, state) = continue.root_A v state in (A' v, state) + | B v -> let (v, state) = continue.root_B v state in (B' v, state) + | C v -> let (v, state) = continue.root_C v state in (C' v, state) + ); + root_A = (fun v state continue -> continue.a v state ) ; + root_B = (fun v state continue -> ignore continue; (v, state) ) ; + root_C = (fun v state continue -> ignore continue; (v, state) ) ; + a = (fun v state continue -> + match v with + { a1; a2; } -> + let (a1', state) = continue.a_a1 a1 state in + let (a2', state) = continue.a_a2 a2 state in + ({ a1'; a2'; }, state) + ); + a_a1 = (fun v state continue -> continue.ta1 v state ) ; + a_a2 = (fun v state continue -> continue.ta2 v state ) ; + ta1 = (fun v state continue -> + match v with + | X v -> let (v, state) = continue.ta1_X v state in (X' v, state) + | Y v -> let (v, state) = continue.ta1_Y v state in (Y' v, state) + ); + ta1_X = (fun v state continue -> continue.root v state ) ; + ta1_Y = (fun v state continue -> continue.ta2 v state ) ; + ta2 = (fun v state continue -> + match v with + | Z v -> let (v, state) = continue.ta2_Z v state in (Z' v, state) + | W v -> let (v, state) = continue.ta2_W v state in (W' v, state) + ); + ta2_Z = (fun v state continue -> continue.ta2 v state ) ; + ta2_W = (fun v state continue -> ignore continue; (v, state) ) ; +} diff --git a/src/stages/adt_generator/generator.py b/src/stages/adt_generator/generator.py index ad537f4d7..2c8ea5dbf 100644 --- a/src/stages/adt_generator/generator.py +++ b/src/stages/adt_generator/generator.py @@ -1,3 +1,4 @@ +moduleName = "A" adts = [ # typename, variant?, fields_or_ctors ("root", True, [ @@ -16,56 +17,168 @@ adts = [ ]), ("ta2", True, [ ("Z", False, "ta2"), + ("W", True, "unit"), ]), ] -print "type 'state fold_config = {" -for (t, is_variant, ctors) in adts: - tt = ("%s'" % (t,)) # output type t' - print (" %s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, t, tt)) - for (c, builtin, ct,) in ctors: - if builtin: - ctt = ct # TODO: use a wrapper instead of a' for the intermediate steps, and target a different type a' just to change what the output type is - else: - ctt = ("%s'" % (ct,)) - print (" %s_%s : %s -> 'state -> ('state continue_fold) -> (%s * 'state) ;" % (t, c, ct, ctt)) -print " }" -print "" +from collections import namedtuple +adt = namedtuple('adt', ['name', 'newName', 'isVariant', 'ctorsOrFields']) +ctorOrField = namedtuple('ctorOrField', ['name', 'newName', 'isBuiltin', 'type_', 'newType']) +adts = [ + adt( + name = name, + newName = f"{name}'", + isVariant = isVariant, + ctorsOrFields = [ + ctorOrField( + name = cf, + newName = f"{cf}'", + isBuiltin = isBuiltin, + type_ = type_, + newType = type_ if isBuiltin else f"{type_}'", + ) + for (cf, isBuiltin, type_) in ctors + ], + ) + for (name, isVariant, ctors) in adts +] -print "let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->" -print " {" -for (t, is_variant, ctors) in adts: - print (" %s = fold_%s visitor ;" % (t, t)) -print " }" -print "" +print("open %s" % moduleName) -for (t, is_variant, ctors) in adts: - v = t # visitor field - tt = ("%s'" % (t,)) # output type t' - print ("and fold_%s : type state . state fold_config -> %s -> state -> (%s * state) = fun visitor x state ->" % (t, t, tt,)) - print " let continue_fold : state continue_fold = mk_continue_fold visitor in" - print " match x with" - if is_variant: - for (c, builtin, ct,) in ctors: - cc = ("%s'" % (c,)) - print (" | %s v ->" % (c,)) - print (" let (v, state) = visitor.%s_%s v state continue_fold in" % (t, c,)) - if not builtin: - print (" let (v, state) = visitor.%s v state continue_fold in" % (ct,)) - print (" (%s v, state)" % (cc,)) +print("") +for (index, t) in enumerate(adts): + typeOrAnd = "type" if index == 0 else "and" + print(f"{typeOrAnd} {t.newName} =") + if t.isVariant: + for c in t.ctorsOrFields: + print(f" | {c.newName} of {c.newType}") else: - print " | {" - for (f, builtin, ft,) in ctors: - print (" %s;" % (f,)) - print " } ->" - for (f, builtin, ft,) in ctors: - ff = ("%s'" % (f,)) - print (" let (%s, state) = visitor.%s_%s %s state continue_fold in" % (f, t, f, f,)) - if not builtin: - print (" let (%s, state) = visitor.%s %s state continue_fold in" % (ff, ft, f,)) - print " ({" - for (f, builtin, ft,) in ctors: - ff = ("%s'" % (f,)) - print (" %s;" % (ff,)) - print " }, state)" - print "" + print(" {") + for f in t.ctorsOrFields: + print(f" {f.newName} : {f.newType} ;") + print(" }") + + +# print("") +# print("type 'state continue_fold =") +# print(" {") +# for t in adts: +# print(f" {t.name} : {t.name} -> 'state -> ({t.newName} * 'state) ;") +# print(" }") + +def folder(name, extraArgs): + print("") + print(f"type 'state {name} =") + print(" {") + for t in adts: + print(f" {t.name} : {t.name} -> 'state{extraArgs} -> ({t.newName} * 'state) ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} : {c.type_} -> 'state{extraArgs} -> ({c.newType} * 'state) ;") + print(" }") + +folder("continue_fold", "") +folder("fold_config", " -> ('state continue_fold)") + +print("") +print('(* Curries the "visitor" argument to the folds (non-customizable traversal functions). *)') +print("let rec mk_continue_fold : type state . state fold_config -> state continue_fold = fun visitor ->") +print(" {") +for t in adts: + print(f" {t.name} = fold_{t.name} visitor ;") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} = fold_{t.name}_{c.name} visitor ;") +print("}") +print("") + +for t in adts: + print(f"and fold_{t.name} : type state . state fold_config -> {t.name} -> state -> ({t.newName} * state) = fun visitor x state ->") + print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" visitor.{t.name} x state continue_fold") + print("") + for c in t.ctorsOrFields: + print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") + print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + print(f" visitor.{t.name}_{c.name} x state continue_fold") + print("") + + # print(" match x with") + # if t.isVariant: + # for c in t.ctorsOrFields: + # print(f" | {c.name} v ->") + # print(f" let (v', state) = visitor.{t.name}_{c.name} v state continue_fold in") + # print(f" ({c.newName} v', state)") + # else: + # print(" | {", end=' ') + # for f in t.ctorsOrFields: + # print(f"{f.name};", end=' ') + # print("} ->") + # for f in t.ctorsOrFields: + # print(f" let ({f.newName}, state) = visitor.{t.name}_{f.name} {f.name} state continue_fold in") + # print(" ({", end=' ') + # for f in t.ctorsOrFields: + # print(f"{f.newName};", end=' ') + # print("}, state)") + # print("") + # for c in t.ctorsOrFields: + # print(f"and fold_{t.name}_{c.name} : type state . state fold_config -> {c.type_} -> state -> ({c.newType} * state) = fun visitor x state ->") + # if c.isBuiltin: + # print(" ignore visitor; (x, state)") + # else: + # print(" let continue_fold : state continue_fold = mk_continue_fold visitor in") + # print(f" visitor.{c.type_} x state continue_fold") + # print("") + +# print """let no_op : ('a -> unit) -> 'a fold_config = fun phantom -> failwith "todo" """ + +print("let no_op : 'a fold_config = {") +for t in adts: + print(f" {t.name} = (fun v state continue ->") + print(" match v with") + if t.isVariant: + for c in t.ctorsOrFields: + print(f" | {c.name} v -> let (v, state) = continue.{t.name}_{c.name} v state in ({c.newName} v, state)") + else: + print(" {", end=' ') + for f in t.ctorsOrFields: + print(f"{f.name};", end=' ') + print("} ->") + for f in t.ctorsOrFields: + print(f" let ({f.newName}, state) = continue.{t.name}_{f.name} {f.name} state in") + print(" ({", end=' ') + for f in t.ctorsOrFields: + print(f"{f.newName};", end=' ') + print("}, state)") + print(" );") + for c in t.ctorsOrFields: + print(f" {t.name}_{c.name} = (fun v state continue ->", end=' ') + if c.isBuiltin: + print("ignore continue; (v, state)", end=' ') + else: + print(f"continue.{c.type_} v state", end=' ') + print(") ;") +print("}") + + + # (fun v state continue -> + # let (new_v, new_state) = match v with + # | A v -> let (v, state) = continue.a v state in (A' v, state) + # | B v -> let (v, state) = (fun x s -> (x,s)) v state in (B' v, state) + # | C v -> let (v, state) = (fun x s -> (x,s)) v state in (C' v, state) + # in + # (new_v, new_state) + # ); + + + + + + + # if not builtin: + # print (" let (v', state) = match v' with None -> visitor.%s v state continue_fold | Some v' -> (v', state) in" % (ct,)) + # else: + # print " let Some v' = v' in" + + # if not builtin: + # print (" let (%s, state) = match %s with None -> visitor.%s %s state continue_fold | Some v' -> (v', state) in" % (ff, ff, ft, f)) + # else: + # print " let Some v' = v' in" diff --git a/src/stages/adt_generator/use_a_fold.ml b/src/stages/adt_generator/use_a_fold.ml index 03f4cca3e..13f78e040 100644 --- a/src/stages/adt_generator/use_a_fold.ml +++ b/src/stages/adt_generator/use_a_fold.ml @@ -2,7 +2,7 @@ open A open Fold let _ = - let some_root = ((failwith "assume we have some root") : root) in + let some_root : root = A { a1 = X (A { a1 = X (B 1) ; a2 = W () ; }) ; a2 = Z (W ()) ; } in let op = { no_op with a = fun the_a state continue_fold -> @@ -11,8 +11,12 @@ let _ = ({ a1' = a1' ; a2' = a2' ; - }, state'') + }, state'' + 1) } in - let state = () in - fold_root op some_root state + let state = 0 in + let (_, state) = fold_root op some_root state in + Printf.printf "trilili %d" state + +let _noi : int fold_config = no_op (* (fun _ -> ()) *) +let _nob : bool fold_config = no_op (* (fun _ -> ()) *)