Utils/Socket: use getaddrinfo instead of gethostbyname

This commit is contained in:
Vincent Bernardoff 2018-07-15 21:49:02 +02:00 committed by Benjamin Canou
parent b76b5367bb
commit fe21585462
2 changed files with 46 additions and 39 deletions

View File

@ -280,56 +280,62 @@ module Socket = struct
type addr =
| Unix of string
| Tcp of string * int
| Tcp of string * string * Unix.getaddrinfo_option list
let get_addrs host =
try return (Array.to_list (Unix.gethostbyname host).h_addr_list)
with Not_found -> failwith "Host %s not found" host
let handle_litteral_ipv6 host =
(* To strip '[' and ']' when a litteral IPv6 is provided *)
match Ipaddr.of_string host with
| None -> host
| Some ipaddr -> Ipaddr.to_string ipaddr
let connect path =
match path with
let connect = function
| Unix path ->
let addr = Lwt_unix.ADDR_UNIX path in
let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in
Lwt_unix.connect sock addr >>= fun () ->
return sock
| Tcp (host, port) ->
get_addrs host >>=? fun addrs ->
let rec try_connect = function
| [] -> failwith "could not resolve host '%s'" host
| addr :: addrs ->
Lwt.catch
(fun () ->
let addr = Lwt_unix.ADDR_INET (addr, port) in
let sock = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
Lwt_unix.connect sock addr >>= fun () ->
return sock)
(fun _ -> try_connect addrs) in
try_connect addrs
| Tcp (host, service, opts) ->
let host = handle_litteral_ipv6 host in
Lwt_unix.getaddrinfo host service opts >>= function
| [] ->
failwith "could not resolve host '%s'" host
| addrs ->
let rec try_connect = function
| [] ->
failwith "could not connect to '%s'" host
| { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } :: addrs ->
let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in
Lwt.catch
(fun () ->
Lwt_unix.connect sock ai_addr >>= fun () ->
return sock)
(fun exn ->
Format.printf "@{<error>@{<title>Unable to connect to %s@}@}@.\
\ @[<h 0>%a@]@."
host Format.pp_print_text (Printexc.to_string exn) ;
Lwt_unix.close sock >>= fun () ->
try_connect addrs) in
try_connect addrs
let bind ?(backlog = 10) path =
match path with
let bind ?(backlog = 10) = function
| Unix path ->
let addr = Lwt_unix.ADDR_UNIX path in
let sock = Lwt_unix.socket PF_UNIX SOCK_STREAM 0 in
Lwt_unix.bind sock addr >>= fun () ->
Lwt_unix.listen sock backlog ;
return sock
| Tcp (host, port) ->
get_addrs host >>=? fun addrs ->
let rec try_bind = function
| [] -> failwith "could not resolve host '%s'" host
| addr :: addrs ->
Lwt.catch
(fun () ->
let addr = Lwt_unix.ADDR_INET (addr, port) in
let sock = Lwt_unix.socket PF_INET SOCK_STREAM 0 in
Lwt_unix.setsockopt sock SO_REUSEADDR true ;
Lwt_unix.bind sock addr >>= fun () ->
Lwt_unix.listen sock backlog ;
return sock)
(fun _ -> try_bind addrs) in
try_bind addrs
return [sock]
| Tcp (host, service, opts) ->
Lwt_unix.getaddrinfo
(handle_litteral_ipv6 host) service (AI_PASSIVE :: opts) >>= function
| [] -> failwith "could not resolve host '%s'" host
| addrs ->
let do_bind { Unix.ai_family; ai_socktype; ai_protocol; ai_addr } =
let sock = Lwt_unix.socket ai_family ai_socktype ai_protocol in
Lwt_unix.setsockopt sock SO_REUSEADDR true ;
Lwt_unix.bind sock ai_addr >>= fun () ->
Lwt_unix.listen sock backlog ;
return sock in
map_s do_bind addrs
type error +=
| Encoding_error

View File

@ -76,10 +76,11 @@ module Socket : sig
type addr =
| Unix of string
| Tcp of string * int
| Tcp of string * string * Unix.getaddrinfo_option list
val connect: addr -> Lwt_unix.file_descr tzresult Lwt.t
val bind: ?backlog:int -> addr -> Lwt_unix.file_descr tzresult Lwt.t
val bind:
?backlog:int -> addr -> Lwt_unix.file_descr list tzresult Lwt.t
type error +=
| Encoding_error