diff --git a/src/node/net/p2p.ml b/src/node/net/p2p.ml index 9b7f9875d..6c7aafda0 100644 --- a/src/node/net/p2p.ml +++ b/src/node/net/p2p.ml @@ -581,7 +581,7 @@ module Make (P: PARAMS) = struct "(%a) connection handler for %a:%d" pp_gid my_gid Ipaddr.pp_hum addr port in ignore (Lwt_utils.worker worker_name - ~safe:true ~run:(fun () -> connect buf) ~cancel) ; + ~run:(fun () -> connect buf) ~cancel) ; (* return the canceler *) cancel diff --git a/src/utils/lwt_exit.ml b/src/utils/lwt_exit.ml index 2e82b1fe8..09481e875 100644 --- a/src/utils/lwt_exit.ml +++ b/src/utils/lwt_exit.ml @@ -7,7 +7,6 @@ (* *) (**************************************************************************) - exception Exit let termination_thread, exit_wakener = Lwt.wait () @@ -18,6 +17,12 @@ let () = (function | Exit -> () | exn -> - Printf.eprintf "Uncaught (asynchronous) exception: %S\n%s\n%!" - (Printexc.to_string exn) (Printexc.get_backtrace ()); + Format.eprintf + "@[Uncaught (asynchronous) exception (%d):@ %a@]" + (Unix.getpid ()) + Error_monad.pp_exn exn ; + let backtrace = Printexc.get_backtrace () in + if String.length backtrace <> 0 then + Format.eprintf "\n%s" backtrace ; + Format.eprintf "@." ; Lwt.wakeup exit_wakener 1) diff --git a/src/utils/lwt_utils.ml b/src/utils/lwt_utils.ml index d3ad37d5b..00f857ab9 100644 --- a/src/utils/lwt_utils.ml +++ b/src/utils/lwt_utils.ml @@ -12,7 +12,7 @@ module LC = Lwt_condition open Lwt.Infix open Logging.Core -let may f = function +let may ~f = function | None -> Lwt.return_unit | Some x -> f x @@ -39,10 +39,13 @@ let canceler () else begin canceling := true ; LC.broadcast cancelation () ; - !cancel_hook () >>= fun () -> - canceled := true ; - LC.broadcast cancelation_complete () ; - Lwt.return () + Lwt.finalize + !cancel_hook + (fun () -> + canceled := true ; + LC.broadcast cancelation_complete () ; + Lwt.return ()) >>= fun () -> + Lwt.return_unit end in let on_cancel cb = @@ -55,6 +58,53 @@ let canceler () in cancelation, cancel, on_cancel +module Canceler = struct + + type t = { + cancelation: unit Lwt_condition.t ; + cancelation_complete: unit Lwt_condition.t ; + mutable cancel_hook: unit -> unit Lwt.t ; + mutable canceling: bool ; + mutable canceled: bool ; + } + + let create () = + let cancelation = LC.create () in + let cancelation_complete = LC.create () in + { cancelation ; cancelation_complete ; + cancel_hook = (fun () -> Lwt.return ()) ; + canceling = false ; + canceled = false ; + } + + let cancel st = + if st.canceled then + Lwt.return () + else if st.canceling then + LC.wait st.cancelation_complete + else begin + st.canceling <- true ; + LC.broadcast st.cancelation () ; + Lwt.finalize + st.cancel_hook + (fun () -> + st.canceled <- true ; + LC.broadcast st.cancelation_complete () ; + Lwt.return ()) + end + + let on_cancel st cb = + let hook = st.cancel_hook in + st.cancel_hook <- (fun () -> hook () >>= cb) + + let cancelation st = + if st.canceling then Lwt.return () + else LC.wait st.cancelation + + let canceled st = st.canceling + +end + type trigger = | Absent | Present @@ -114,12 +164,11 @@ let queue () : ('a -> unit) * (unit -> 'a list Lwt.t) = queue, wait (* A worker launcher, takes a cancel callback to call upon *) -let worker ?(safe=false) name ~run ~cancel = +let worker name ~run ~cancel = let stop = LC.create () in let fail e = log_error "%s worker failed with %s" name (Printexc.to_string e) ; - cancel () >>= fun () -> - if safe then Lwt.return_unit else Lwt.fail e + cancel () in let waiter = LC.wait stop in log_info "%s worker started" name ; @@ -263,6 +312,17 @@ let write_mbytes ?(pos=0) ?len descr buf = | nb_written -> inner (pos + nb_written) (len - nb_written) in inner pos len +let write_bytes ?(pos=0) ?len descr buf = + let len = match len with None -> Bytes.length buf - pos | Some l -> l in + let rec inner pos len = + if len = 0 then + Lwt.return_unit + else + Lwt_unix.write descr buf pos len >>= function + | 0 -> Lwt.fail End_of_file (* other endpoint cleanly closed its connection *) + | nb_written -> inner (pos + nb_written) (len - nb_written) in + inner pos len + let (>>=) = Lwt.bind let remove_dir dir = @@ -297,3 +357,49 @@ let create_file ?(perm = 0o644) name content = Lwt_unix.openfile name Unix.([O_TRUNC; O_CREAT; O_WRONLY]) perm >>= fun fd -> Lwt_unix.write_string fd content 0 (String.length content) >>= fun _ -> Lwt_unix.close fd + +let safe_close fd = + Lwt.catch + (fun () -> Lwt_unix.close fd) + (fun _ -> Lwt.return_unit) + +open Error_monad + +type error += Canceled + +let protect ?on_error ?canceler t = + let cancelation = + match canceler with + | None -> never_ending + | Some canceler -> + ( Canceler.cancelation canceler >>= fun () -> + fail Canceled ) in + let res = + Lwt.pick [ cancelation ; + Lwt.catch t (fun exn -> fail (Exn exn)) ] in + res >>= function + | Ok _ -> res + | Error err -> + let canceled = + Utils.unopt_map canceler ~default:false ~f:Canceler.canceled in + let err = if canceled then [Canceled] else err in + match on_error with + | None -> Lwt.return (Error err) + | Some on_error -> on_error err + +type error += Timeout + +let with_timeout ?(canceler = Canceler.create ()) timeout f = + let t = Lwt_unix.sleep timeout in + Lwt.choose [ + (t >|= fun () -> None) ; + (f canceler >|= fun x -> Some x) + ] >>= function + | Some x when Lwt.state t = Lwt.Sleep -> + Lwt.cancel t ; + Lwt.return x + | _ -> + Canceler.cancel canceler >>= fun () -> + fail Timeout + + diff --git a/src/utils/lwt_utils.mli b/src/utils/lwt_utils.mli index 0fd73d6cd..78cf995a2 100644 --- a/src/utils/lwt_utils.mli +++ b/src/utils/lwt_utils.mli @@ -7,7 +7,7 @@ (* *) (**************************************************************************) -val may : ('a -> unit Lwt.t) -> 'a option -> unit Lwt.t +val may: f:('a -> unit Lwt.t) -> 'a option -> unit Lwt.t val never_ending: 'a Lwt.t @@ -16,8 +16,18 @@ val canceler : unit -> (unit -> unit Lwt.t) * ((unit -> unit Lwt.t) -> unit) +module Canceler : sig + + type t + val create : unit -> t + val cancel : t -> unit Lwt.t + val cancelation : t -> unit Lwt.t + val on_cancel : t -> (unit -> unit Lwt.t) -> unit + val canceled : t -> bool + +end + val worker: - ?safe:bool -> string -> run:(unit -> unit Lwt.t) -> cancel:(unit -> unit Lwt.t) -> @@ -33,9 +43,27 @@ val read_bytes: val read_mbytes: ?pos:int -> ?len:int -> Lwt_unix.file_descr -> MBytes.t -> unit Lwt.t +val write_bytes: + ?pos:int -> ?len:int -> Lwt_unix.file_descr -> bytes -> unit Lwt.t val write_mbytes: ?pos:int -> ?len:int -> Lwt_unix.file_descr -> MBytes.t -> unit Lwt.t val remove_dir: string -> unit Lwt.t val create_dir: ?perm:int -> string -> unit Lwt.t val create_file: ?perm:int -> string -> string -> unit Lwt.t + +val safe_close: Lwt_unix.file_descr -> unit Lwt.t + +open Error_monad + +type error += Canceled +val protect : + ?on_error:(error list -> 'a tzresult Lwt.t) -> + ?canceler:Canceler.t -> + (unit -> 'a tzresult Lwt.t) -> 'a tzresult Lwt.t + +type error += Timeout +val with_timeout: + ?canceler:Canceler.t -> + float -> (Canceler.t -> 'a tzresult Lwt.t) -> 'a tzresult Lwt.t +