Shell: improve Lwt_utils

This commit is contained in:
Vincent Bernardoff 2017-01-14 13:13:27 +01:00 committed by Grégoire Henry
parent 0027d70158
commit 6b3e002285
4 changed files with 153 additions and 14 deletions

View File

@ -581,7 +581,7 @@ module Make (P: PARAMS) = struct
"(%a) connection handler for %a:%d"
pp_gid my_gid Ipaddr.pp_hum addr port in
ignore (Lwt_utils.worker worker_name
~safe:true ~run:(fun () -> connect buf) ~cancel) ;
~run:(fun () -> connect buf) ~cancel) ;
(* return the canceler *)
cancel

View File

@ -7,7 +7,6 @@
(* *)
(**************************************************************************)
exception Exit
let termination_thread, exit_wakener = Lwt.wait ()
@ -18,6 +17,12 @@ let () =
(function
| Exit -> ()
| exn ->
Printf.eprintf "Uncaught (asynchronous) exception: %S\n%s\n%!"
(Printexc.to_string exn) (Printexc.get_backtrace ());
Format.eprintf
"@[Uncaught (asynchronous) exception (%d):@ %a@]"
(Unix.getpid ())
Error_monad.pp_exn exn ;
let backtrace = Printexc.get_backtrace () in
if String.length backtrace <> 0 then
Format.eprintf "\n%s" backtrace ;
Format.eprintf "@." ;
Lwt.wakeup exit_wakener 1)

View File

@ -12,7 +12,7 @@ module LC = Lwt_condition
open Lwt.Infix
open Logging.Core
let may f = function
let may ~f = function
| None -> Lwt.return_unit
| Some x -> f x
@ -39,10 +39,13 @@ let canceler ()
else begin
canceling := true ;
LC.broadcast cancelation () ;
!cancel_hook () >>= fun () ->
canceled := true ;
LC.broadcast cancelation_complete () ;
Lwt.return ()
Lwt.finalize
!cancel_hook
(fun () ->
canceled := true ;
LC.broadcast cancelation_complete () ;
Lwt.return ()) >>= fun () ->
Lwt.return_unit
end
in
let on_cancel cb =
@ -55,6 +58,53 @@ let canceler ()
in
cancelation, cancel, on_cancel
module Canceler = struct
type t = {
cancelation: unit Lwt_condition.t ;
cancelation_complete: unit Lwt_condition.t ;
mutable cancel_hook: unit -> unit Lwt.t ;
mutable canceling: bool ;
mutable canceled: bool ;
}
let create () =
let cancelation = LC.create () in
let cancelation_complete = LC.create () in
{ cancelation ; cancelation_complete ;
cancel_hook = (fun () -> Lwt.return ()) ;
canceling = false ;
canceled = false ;
}
let cancel st =
if st.canceled then
Lwt.return ()
else if st.canceling then
LC.wait st.cancelation_complete
else begin
st.canceling <- true ;
LC.broadcast st.cancelation () ;
Lwt.finalize
st.cancel_hook
(fun () ->
st.canceled <- true ;
LC.broadcast st.cancelation_complete () ;
Lwt.return ())
end
let on_cancel st cb =
let hook = st.cancel_hook in
st.cancel_hook <- (fun () -> hook () >>= cb)
let cancelation st =
if st.canceling then Lwt.return ()
else LC.wait st.cancelation
let canceled st = st.canceling
end
type trigger =
| Absent
| Present
@ -114,12 +164,11 @@ let queue () : ('a -> unit) * (unit -> 'a list Lwt.t) =
queue, wait
(* A worker launcher, takes a cancel callback to call upon *)
let worker ?(safe=false) name ~run ~cancel =
let worker name ~run ~cancel =
let stop = LC.create () in
let fail e =
log_error "%s worker failed with %s" name (Printexc.to_string e) ;
cancel () >>= fun () ->
if safe then Lwt.return_unit else Lwt.fail e
cancel ()
in
let waiter = LC.wait stop in
log_info "%s worker started" name ;
@ -263,6 +312,17 @@ let write_mbytes ?(pos=0) ?len descr buf =
| nb_written -> inner (pos + nb_written) (len - nb_written) in
inner pos len
let write_bytes ?(pos=0) ?len descr buf =
let len = match len with None -> Bytes.length buf - pos | Some l -> l in
let rec inner pos len =
if len = 0 then
Lwt.return_unit
else
Lwt_unix.write descr buf pos len >>= function
| 0 -> Lwt.fail End_of_file (* other endpoint cleanly closed its connection *)
| nb_written -> inner (pos + nb_written) (len - nb_written) in
inner pos len
let (>>=) = Lwt.bind
let remove_dir dir =
@ -297,3 +357,49 @@ let create_file ?(perm = 0o644) name content =
Lwt_unix.openfile name Unix.([O_TRUNC; O_CREAT; O_WRONLY]) perm >>= fun fd ->
Lwt_unix.write_string fd content 0 (String.length content) >>= fun _ ->
Lwt_unix.close fd
let safe_close fd =
Lwt.catch
(fun () -> Lwt_unix.close fd)
(fun _ -> Lwt.return_unit)
open Error_monad
type error += Canceled
let protect ?on_error ?canceler t =
let cancelation =
match canceler with
| None -> never_ending
| Some canceler ->
( Canceler.cancelation canceler >>= fun () ->
fail Canceled ) in
let res =
Lwt.pick [ cancelation ;
Lwt.catch t (fun exn -> fail (Exn exn)) ] in
res >>= function
| Ok _ -> res
| Error err ->
let canceled =
Utils.unopt_map canceler ~default:false ~f:Canceler.canceled in
let err = if canceled then [Canceled] else err in
match on_error with
| None -> Lwt.return (Error err)
| Some on_error -> on_error err
type error += Timeout
let with_timeout ?(canceler = Canceler.create ()) timeout f =
let t = Lwt_unix.sleep timeout in
Lwt.choose [
(t >|= fun () -> None) ;
(f canceler >|= fun x -> Some x)
] >>= function
| Some x when Lwt.state t = Lwt.Sleep ->
Lwt.cancel t ;
Lwt.return x
| _ ->
Canceler.cancel canceler >>= fun () ->
fail Timeout

View File

@ -7,7 +7,7 @@
(* *)
(**************************************************************************)
val may : ('a -> unit Lwt.t) -> 'a option -> unit Lwt.t
val may: f:('a -> unit Lwt.t) -> 'a option -> unit Lwt.t
val never_ending: 'a Lwt.t
@ -16,8 +16,18 @@ val canceler : unit ->
(unit -> unit Lwt.t) *
((unit -> unit Lwt.t) -> unit)
module Canceler : sig
type t
val create : unit -> t
val cancel : t -> unit Lwt.t
val cancelation : t -> unit Lwt.t
val on_cancel : t -> (unit -> unit Lwt.t) -> unit
val canceled : t -> bool
end
val worker:
?safe:bool ->
string ->
run:(unit -> unit Lwt.t) ->
cancel:(unit -> unit Lwt.t) ->
@ -33,9 +43,27 @@ val read_bytes:
val read_mbytes:
?pos:int -> ?len:int -> Lwt_unix.file_descr -> MBytes.t -> unit Lwt.t
val write_bytes:
?pos:int -> ?len:int -> Lwt_unix.file_descr -> bytes -> unit Lwt.t
val write_mbytes:
?pos:int -> ?len:int -> Lwt_unix.file_descr -> MBytes.t -> unit Lwt.t
val remove_dir: string -> unit Lwt.t
val create_dir: ?perm:int -> string -> unit Lwt.t
val create_file: ?perm:int -> string -> string -> unit Lwt.t
val safe_close: Lwt_unix.file_descr -> unit Lwt.t
open Error_monad
type error += Canceled
val protect :
?on_error:(error list -> 'a tzresult Lwt.t) ->
?canceler:Canceler.t ->
(unit -> 'a tzresult Lwt.t) -> 'a tzresult Lwt.t
type error += Timeout
val with_timeout:
?canceler:Canceler.t ->
float -> (Canceler.t -> 'a tzresult Lwt.t) -> 'a tzresult Lwt.t