Lwt_utils: connect: add a ?timeout argument

This is to replace the default UNIX timeout that can be very long (30s
or so). In the context of baking, it is not acceptable to wait for
such a long time to connect to e.g. a signer daemon whenever there is
multiple addresses available for load balancing.
This commit is contained in:
Vincent Bernardoff 2018-09-04 00:23:51 +09:00 committed by Grégoire Henry
parent cc848fc479
commit f458b6119c
No known key found for this signature in database
GPG Key ID: 827A020B224844F1
2 changed files with 28 additions and 15 deletions

View File

@ -288,7 +288,7 @@ module Socket = struct
| Error (`Msg _) -> host
| Ok ipaddr -> Ipaddr.to_string ipaddr
let connect = function
let connect ?(timeout=5.) = function
| Unix path ->
let addr = Lwt_unix.ADDR_UNIX path in
let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in
@ -300,22 +300,24 @@ module Socket = struct
| [] ->
failwith "could not resolve host '%s'" host
| addrs ->
let rec try_connect = function
let rec try_connect acc = function
| [] ->
failwith "could not connect to '%s'" host
Lwt.return
(Error (failure "could not connect to '%s'" host :: List.rev acc))
| { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } :: addrs ->
let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in
Lwt.catch
(fun () ->
protect ~on_error:begin fun e ->
Lwt_unix.close sock >>= fun () ->
Lwt.return (Error e)
end begin fun () ->
with_timeout (Lwt_unix.sleep timeout) (fun _c ->
Lwt_unix.connect sock ai_addr >>= fun () ->
return sock)
(fun exn ->
Format.printf "@{<error>@{<title>Unable to connect to %s@}@}@.\
\ @[<h 0>%a@]@."
host Format.pp_print_text (Printexc.to_string exn) ;
Lwt_unix.close sock >>= fun () ->
try_connect addrs) in
try_connect addrs
end >>= function
| Ok sock -> return sock
| Error e ->
try_connect (e @ acc) addrs in
try_connect [] addrs
let bind ?(backlog = 10) = function
| Unix path ->

View File

@ -78,7 +78,18 @@ module Socket : sig
| Unix of string
| Tcp of string * string * Unix.getaddrinfo_option list
val connect: addr -> Lwt_unix.file_descr tzresult Lwt.t
val connect:
?timeout:float -> addr -> Lwt_unix.file_descr tzresult Lwt.t
(** [connect ?timeout addr] tries connecting to [addr] and returns
the resulting socket file descriptor on success. When using TCP,
[Unix.getaddrinfo] is used to resolve the hostname and service
(port). The different socket addresses returned by
[Unix.getaddrinfo] are tried sequentially, and the [?timeout]
argument (default: 5s) governs how long it waits to get a
connection. If a connection is not obtained in less than
[?timeout], the connection is canceled and and the next socket
address (if it exists) is tried. *)
val bind:
?backlog:int -> addr -> Lwt_unix.file_descr list tzresult Lwt.t