diff --git a/src/lib_stdlib_unix/lwt_utils_unix.ml b/src/lib_stdlib_unix/lwt_utils_unix.ml index 6326288a2..2225e93f4 100644 --- a/src/lib_stdlib_unix/lwt_utils_unix.ml +++ b/src/lib_stdlib_unix/lwt_utils_unix.ml @@ -280,56 +280,62 @@ module Socket = struct type addr = | Unix of string - | Tcp of string * int + | Tcp of string * string * Unix.getaddrinfo_option list - let get_addrs host = - try return (Array.to_list (Unix.gethostbyname host).h_addr_list) - with Not_found -> failwith "Host %s not found" host + let handle_litteral_ipv6 host = + (* To strip '[' and ']' when a litteral IPv6 is provided *) + match Ipaddr.of_string host with + | None -> host + | Some ipaddr -> Ipaddr.to_string ipaddr - let connect path = - match path with + let connect = function | Unix path -> let addr = Lwt_unix.ADDR_UNIX path in let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in Lwt_unix.connect sock addr >>= fun () -> return sock - | Tcp (host, port) -> - get_addrs host >>=? fun addrs -> - let rec try_connect = function - | [] -> failwith "could not resolve host '%s'" host - | addr :: addrs -> - Lwt.catch - (fun () -> - let addr = Lwt_unix.ADDR_INET (addr, port) in - let sock = Lwt_unix.socket PF_INET SOCK_STREAM 0 in - Lwt_unix.connect sock addr >>= fun () -> - return sock) - (fun _ -> try_connect addrs) in - try_connect addrs + | Tcp (host, service, opts) -> + let host = handle_litteral_ipv6 host in + Lwt_unix.getaddrinfo host service opts >>= function + | [] -> + failwith "could not resolve host '%s'" host + | addrs -> + let rec try_connect = function + | [] -> + failwith "could not connect to '%s'" host + | { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } :: addrs -> + let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in + Lwt.catch + (fun () -> + Lwt_unix.connect sock ai_addr >>= fun () -> + return sock) + (fun exn -> + Format.printf "@{@{Unable to connect to %s@}@}@.\ + \ @[<h 0>%a@]@." + host Format.pp_print_text (Printexc.to_string exn) ; + Lwt_unix.close sock >>= fun () -> + try_connect addrs) in + try_connect addrs - let bind ?(backlog = 10) path = - match path with + let bind ?(backlog = 10) = function | Unix path -> let addr = Lwt_unix.ADDR_UNIX path in let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in Lwt_unix.bind sock addr >>= fun () -> Lwt_unix.listen sock backlog ; - return sock - | Tcp (host, port) -> - get_addrs host >>=? fun addrs -> - let rec try_bind = function - | [] -> failwith "could not resolve host '%s'" host - | addr :: addrs -> - Lwt.catch - (fun () -> - let addr = Lwt_unix.ADDR_INET (addr, port) in - let sock = Lwt_unix.socket PF_INET SOCK_STREAM 0 in - Lwt_unix.setsockopt sock SO_REUSEADDR true ; - Lwt_unix.bind sock addr >>= fun () -> - Lwt_unix.listen sock backlog ; - return sock) - (fun _ -> try_bind addrs) in - try_bind addrs + return [sock] + | Tcp (host, service, opts) -> + Lwt_unix.getaddrinfo + (handle_litteral_ipv6 host) service (AI_PASSIVE :: opts) >>= function + | [] -> failwith "could not resolve host '%s'" host + | addrs -> + let do_bind { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } = + let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in + Lwt_unix.setsockopt sock SO_REUSEADDR true ; + Lwt_unix.bind sock ai_addr >>= fun () -> + Lwt_unix.listen sock backlog ; + return sock in + map_s do_bind addrs type error += | Encoding_error diff --git a/src/lib_stdlib_unix/lwt_utils_unix.mli b/src/lib_stdlib_unix/lwt_utils_unix.mli index b8e0d2abd..9058e8126 100644 --- a/src/lib_stdlib_unix/lwt_utils_unix.mli +++ b/src/lib_stdlib_unix/lwt_utils_unix.mli @@ -76,10 +76,11 @@ module Socket : sig type addr = | Unix of string - | Tcp of string * int + | Tcp of string * string * Unix.getaddrinfo_option list val connect: addr -> Lwt_unix.file_descr tzresult Lwt.t - val bind: ?backlog:int -> addr -> Lwt_unix.file_descr tzresult Lwt.t + val bind: + ?backlog:int -> addr -> Lwt_unix.file_descr list tzresult Lwt.t type error += | Encoding_error