diff --git a/src/utils/error_monad.ml b/src/utils/error_monad.ml index f689b8410..8ff3375ae 100644 --- a/src/utils/error_monad.ml +++ b/src/utils/error_monad.ml @@ -174,6 +174,11 @@ module Make() = struct let fail s = Lwt.return (Error [ s ]) + let protect ~on_error t = + t >>= function + | Ok res -> return res + | Error err -> on_error err + let (>>?) v f = match v with | Error _ as err -> err @@ -286,6 +291,9 @@ module Make() = struct let fail_unless cond exn = if cond then return () else fail exn + let unless cond f = + if cond then return () else f () + let pp_print_error ppf errors = Format.fprintf ppf "@[Error, dumping error stack:@,%a@]@." (Format.pp_print_list pp) @@ -332,15 +340,20 @@ let error_exn s = Error [ Exn s ] let trace_exn exn f = trace (Exn exn) f let record_trace_exn exn f = record_trace (Exn exn) f +let pp_exn ppf exn = pp ppf (Exn exn) + let () = register_error_kind `Temporary ~id:"failure" ~title:"Generic error" ~description:"Unclassified error" + ~pp:Format.pp_print_string Data_encoding.(obj1 (req "msg" string)) (function | Exn (Failure msg) -> Some msg + | Exn (Unix.Unix_error (err, fn, _)) -> + Some ("Unix error in " ^ fn ^ ": " ^ Unix.error_message err) | Exn exn -> Some (Printexc.to_string exn) | _ -> None) (fun msg -> Exn (Failure msg)) diff --git a/src/utils/error_monad.mli b/src/utils/error_monad.mli index 4b3f0e1b4..11e607101 100644 --- a/src/utils/error_monad.mli +++ b/src/utils/error_monad.mli @@ -29,6 +29,7 @@ val failwith : val error_exn : exn -> 'a tzresult val record_trace_exn : exn -> 'a tzresult -> 'a tzresult val trace_exn : exn -> 'b tzresult Lwt.t -> 'b tzresult Lwt.t +val pp_exn : Format.formatter -> exn -> unit type error += Exn of exn type error += Unclassified of string diff --git a/src/utils/error_monad_sig.ml b/src/utils/error_monad_sig.ml index 02964ae79..493d3f000 100644 --- a/src/utils/error_monad_sig.ml +++ b/src/utils/error_monad_sig.ml @@ -100,6 +100,12 @@ module type S = sig (** Erroneous return on failed assertion *) val fail_unless : bool -> error -> unit tzresult Lwt.t + val unless : bool -> (unit -> unit tzresult Lwt.t) -> unit tzresult Lwt.t + + val protect : + on_error: (error list -> 'a tzresult Lwt.t) -> + 'a tzresult Lwt.t -> 'a tzresult Lwt.t + (** {2 In-monad list iterators} ********************************************) (** A {!List.iter} in the monad *)