500 lines
16 KiB
OCaml
500 lines
16 KiB
OCaml
(**************************************************************************)
(*                                                                        *)
(*    Copyright (c) 2014 - 2017.                                          *)
(*    Dynamic Ledger Solutions, Inc. <contact@tezos.com>                  *)
(*                                                                        *)
(*    All rights reserved. No warranty, explicit or implicit, provided.   *)
(*                                                                        *)
(**************************************************************************)
|
|
|
|
(* TODO decide whether we need to preallocate buffers or not. *)
|
|
|
|
open P2p_types
|
|
(* Logging section of this file. The unqualified [debug], [lwt_debug],
   [log_info], ... calls below presumably come from this functor
   application -- TODO confirm against [Logging.Make]. *)
include Logging.Make (struct
    let name = "p2p.io-scheduler"
  end)
|
|
|
|
(** Hash tables keyed by plain integers; used below to index registered
    connections by their numeric identifier. *)
module Inttbl = Hashtbl.Make (struct
    type t = int
    let equal (a : int) (b : int) = a = b
    let hash (key : int) = Hashtbl.hash key
  end)
|
|
|
|
(* [~alpha] argument given to every [Moving_average.create] call in this
   file (scheduler-wide and per-connection counters). *)
let alpha : float = 0.2
|
|
|
|
(** Operations a [Scheduler] needs in order to move chunks of bytes from
    a source to a sink. *)
module type IO = sig

  (** Label inserted in the log messages of the scheduler. *)
  val name: string

  (** Where chunks are taken from. *)
  type in_param

  (** Fetch the next chunk from the source, or an error. *)
  val pop: in_param -> MBytes.t tzresult Lwt.t

  (** Where chunks are delivered. *)
  type out_param

  (** Deliver one chunk to the sink. *)
  val push: out_param -> MBytes.t -> unit tzresult Lwt.t

  (** Called once when the scheduler drops a connection, with the errors
      that caused the drop. *)
  val close: out_param -> error list -> unit Lwt.t

end
|
|
|
|
(** Error reported by the I/O functions of this file when the underlying
    connection (socket or pipe) turns out to be closed. *)
type error +=
  | Connection_closed
|
|
|
|
module Scheduler(IO : IO) = struct
|
|
|
|
(** State of one scheduler (one direction of traffic). *)
type t = {
  (* Canceled by [shutdown]; the worker loop exits once it is canceled. *)
  canceler: Canceler.t ;
  (* Promise of the background worker running [worker_loop]. *)
  mutable worker: unit Lwt.t ;
  (* Moving average fed with the size of every chunk handled. *)
  counter: Moving_average.t ;
  (* Optional speed limit; with [None], [check_quota] never waits. *)
  max_speed: int option ;
  (* Remaining global budget: decreased by each handled chunk, refilled
     by [update_quota]. *)
  mutable quota: int ;
  (* Broadcast by [update_quota] after the global quota is refilled. *)
  quota_updated: unit Lwt_condition.t ;
  (* Broadcast when a result is queued while both queues were empty. *)
  readys: unit Lwt_condition.t ;
  (* Popped results of connections that still had a positive quota. *)
  readys_high: (connection * MBytes.t tzresult) Queue.t ;
  (* Popped results of connections that had exhausted their quota. *)
  readys_low: (connection * MBytes.t tzresult) Queue.t ;
}

(** State of one connection attached to a scheduler. *)
and connection = {
  (* Identifier, only used in log messages within this functor. *)
  id: int ;
  (* Set by [cancel]; makes later calls to [cancel] no-ops. *)
  mutable closed: bool ;
  (* Canceled at the end of [cancel]. *)
  canceler: Canceler.t ;
  in_param: IO.in_param ;
  out_param: IO.out_param ;
  (* Latest [IO.pop] started for this connection. *)
  mutable current_pop: MBytes.t tzresult Lwt.t ;
  (* Latest [IO.push] started for this connection. *)
  mutable current_push: unit tzresult Lwt.t ;
  (* Moving average fed with the size of every chunk handled. *)
  counter: Moving_average.t ;
  (* Per-connection budget; decreased only for chunks taken from the
     high-priority queue. *)
  mutable quota: int ;
  (* Written by code outside this functor; not read here. *)
  mutable last_quota: int ;
}
|
|
|
|
(** [cancel conn err] drops [conn] unless it is already marked closed:
    it marks it closed, calls [IO.close] on its sink with [err] (any
    exception raised there is ignored) and finally triggers the
    canceler of the connection. *)
let cancel (conn : connection) err =
  (* Best-effort notification of the sink. *)
  let close_sink () =
    Lwt.catch
      (fun () -> IO.close conn.out_param err)
      (fun _exn -> Lwt.return_unit) in
  Lwt_utils.unless conn.closed (fun () ->
      lwt_debug "Connection closed (%d, %s) " conn.id IO.name >>= fun () ->
      conn.closed <- true ;
      close_sink () >>= fun () ->
      Canceler.cancel conn.canceler)
|
|
|
|
(** [waiter st conn] starts a new [IO.pop] for [conn] and, in the
    background, queues its result in [st] once both the pop and the
    current push of the connection have terminated. The previous pop
    must not be pending anymore. *)
let waiter st conn =
  assert (Lwt.state conn.current_pop <> Sleep) ;
  conn.current_pop <- IO.pop conn.in_param ;
  (* Queue [res] according to the quota of the connection, and wake up
     the worker if it could have been waiting on empty queues. *)
  let enqueue res =
    let nothing_pending =
      Queue.is_empty st.readys_high && Queue.is_empty st.readys_low in
    let target =
      if conn.quota > 0 then st.readys_high else st.readys_low in
    Queue.push (conn, res) target ;
    if nothing_pending then Lwt_condition.broadcast st.readys () in
  Lwt.async (fun () ->
      conn.current_pop >>= fun res ->
      conn.current_push >>= fun _ ->
      enqueue res ;
      Lwt.return_unit)
|
|
|
|
(** [wait_data st] returns immediately when at least one popped result
    is queued; otherwise it waits on the [readys] condition. *)
let wait_data st =
  if Queue.is_empty st.readys_high && Queue.is_empty st.readys_low then
    Lwt_condition.wait st.readys
  else
    Lwt.return_unit
|
|
|
|
(** [check_quota st] waits for the next quota refill when a speed limit
    is configured and the global quota is negative; in every other case
    it merely yields. *)
let check_quota st =
  match st.max_speed with
  | Some _ when st.quota < 0 ->
      lwt_debug "scheduler.wait_quota(%s)" IO.name >>= fun () ->
      Lwt_condition.wait st.quota_updated
  | Some _ | None ->
      Lwt_unix.yield ()
|
|
|
|
(** Main loop of a scheduler. Each iteration waits for quota and for a
    popped result, takes it from the high-priority queue when possible
    (from the low-priority one otherwise) and then either drops the
    connection (on error) or starts pushing the chunk and re-arms a pop
    for that connection. The loop ends once the scheduler canceler has
    been canceled. *)
let rec worker_loop st =
  check_quota st >>= fun () ->
  lwt_debug "scheduler.wait(%s)" IO.name >>= fun () ->
  Lwt.pick [
    Canceler.cancelation st.canceler ;
    wait_data st
  ] >>= fun () ->
  if Canceler.canceled st.canceler then
    Lwt.return_unit
  else begin
    let prio = not (Queue.is_empty st.readys_high) in
    let conn, popped =
      Queue.pop (if prio then st.readys_high else st.readys_low) in
    (* Common tail of the failed-pop cases: drop the connection and keep
       serving the other ones. *)
    let drop_and_continue err =
      cancel conn err >>= fun () ->
      worker_loop st in
    match popped with
    | Error [Lwt_utils.Canceled] ->
        worker_loop st
    | Error ([Connection_closed |
              Exn ( Lwt_pipe.Closed |
                    Unix.Unix_error ((EBADF | ETIMEDOUT), _, _) )]
             as err) ->
        lwt_debug "Connection closed (pop: %d, %s)"
          conn.id IO.name >>= fun () ->
        drop_and_continue err
    | Error err ->
        lwt_log_error
          "@[Unexpected error in connection (pop: %d, %s):@ %a@]"
          conn.id IO.name pp_print_error err >>= fun () ->
        drop_and_continue err
    | Ok msg ->
        (* Continuation of the push: expected closures are logged at
           debug level and reported as success, anything else is logged
           as an error and kept as the result of [current_push]. *)
        let after_push = function
          | Ok ()
          | Error [Lwt_utils.Canceled] ->
              return ()
          | Error ([Connection_closed |
                    Exn (Unix.Unix_error (EBADF, _, _) |
                         Lwt_pipe.Closed)] as err) ->
              lwt_debug "Connection closed (push: %d, %s)"
                conn.id IO.name >>= fun () ->
              cancel conn err >>= fun () ->
              return ()
          | Error err ->
              lwt_log_error
                "@[Unexpected error in connection (push: %d, %s):@ %a@]"
                conn.id IO.name pp_print_error err >>= fun () ->
              cancel conn err >>= fun () ->
              Lwt.return (Error err) in
        conn.current_push <- (IO.push conn.out_param msg >>= after_push) ;
        let len = MBytes.length msg in
        lwt_debug "Handle: %d (%d, %s)" len conn.id IO.name >>= fun () ->
        (* Account for the chunk globally and per connection. *)
        Moving_average.add st.counter len ;
        st.quota <- st.quota - len ;
        Moving_average.add conn.counter len ;
        (* Only chunks served from the high-priority queue consume the
           quota of the connection. *)
        if prio then conn.quota <- conn.quota - len ;
        waiter st conn ;
        worker_loop st
  end
|
|
|
|
(** [create max_speed] builds a scheduler whose initial quota is
    [max_speed] (or [0] without limit) and starts its worker loop. *)
let create max_speed =
  let st = {
    canceler = Canceler.create () ;
    worker = Lwt.return_unit ;
    counter = Moving_average.create ~init:0 ~alpha ;
    max_speed ; quota = unopt ~default:0 max_speed ;
    quota_updated = Lwt_condition.create () ;
    readys = Lwt_condition.create () ;
    readys_high = Queue.create () ;
    readys_low = Queue.create () ;
  } in
  let run () = worker_loop st
  and stop () = Canceler.cancel st.canceler in
  (* The placeholder promise stored above is replaced by the real one. *)
  st.worker <- Lwt_utils.worker IO.name ~run ~cancel:stop ;
  st
|
|
|
|
(** [create_connection st in_param out_param canceler id] attaches a new
    connection to [st] and immediately arms its first pop. *)
let create_connection st in_param out_param canceler id =
  debug "scheduler(%s).create_connection (%d)" IO.name id ;
  let counter = Moving_average.create ~init:0 ~alpha in
  let conn = {
    id ;
    closed = false ;
    canceler ;
    in_param ;
    out_param ;
    (* Already-terminated placeholder, required by the assertion at the
       start of [waiter]. *)
    current_pop = Lwt.fail Not_found (* dummy *) ;
    current_push = return () ;
    counter ;
    quota = 0 ;
    last_quota = 0 ;
  } in
  waiter st conn ;
  conn
|
|
|
|
(** [update_quota st] refills the global quota (when a speed limit is
    set; a positive leftover is discarded, a debt is carried over) and
    wakes up the worker. It then moves to the high-priority queue every
    low-priority entry whose connection has a positive quota again. *)
let update_quota st =
  debug "scheduler(%s).update_quota" IO.name ;
  iter_option st.max_speed ~f:(fun quota ->
      st.quota <- (min st.quota 0) + quota ;
      Lwt_condition.broadcast st.quota_updated ()) ;
  if not (Queue.is_empty st.readys_low) then begin
    let still_low = Queue.create () in
    let dispatch ((conn : connection), _ as entry) =
      let dest =
        if conn.quota > 0 then st.readys_high else still_low in
      Queue.push entry dest in
    Queue.iter dispatch st.readys_low ;
    Queue.clear st.readys_low ;
    Queue.transfer still_low st.readys_low
  end
|
|
|
|
(** [shutdown st] cancels the scheduler and waits for its worker to
    terminate. *)
let shutdown st =
  lwt_debug "--> scheduler(%s).shutdown" IO.name >>= fun () ->
  let stopping = Canceler.cancel st.canceler in
  stopping >>= fun () ->
  st.worker >>= fun () ->
  lwt_debug "<-- scheduler(%s).shutdown" IO.name >>= fun () ->
  Lwt.return_unit
|
|
|
|
|
|
end
|
|
|
|
(** Scheduler for inbound traffic: chunks are read from a file
    descriptor and handed over to a pipe. *)
module ReadScheduler = Scheduler(struct

    let name = "io_scheduler(read)"

    (* File descriptor to read from, and size of the buffer allocated
       for each read. *)
    type in_param = Lwt_unix.file_descr * int

    (* Read at most [maxlen] bytes from [fd]. A zero-length read or an
       [ECONNRESET] is reported as [Connection_closed]; any other
       exception is wrapped with [error_exn]. *)
    let pop (fd, maxlen) =
      let read_chunk () =
        let buf = MBytes.create maxlen in
        Lwt_bytes.read fd buf 0 maxlen >>= function
        | 0 -> fail Connection_closed
        | len -> return (MBytes.sub buf 0 len) in
      let on_exn = function
        | Unix.Unix_error (Unix.ECONNRESET, _, _) ->
            fail Connection_closed
        | exn ->
            Lwt.return (error_exn exn) in
      Lwt.catch read_chunk on_exn

    (* Pipe receiving either the chunks read or the closing errors. *)
    type out_param = MBytes.t tzresult Lwt_pipe.t

    (* Hand a chunk over to the pipe; exceptions become [Exn] errors. *)
    let push p msg =
      Lwt.catch
        (fun () -> Lwt_pipe.push p (Ok msg) >>= fun () -> return ())
        (fun exn -> fail (Exn exn))

    (* Forward the closing errors through the pipe, ignoring any
       exception raised while doing so. *)
    let close p err =
      let notify () = Lwt_pipe.push p (Error err) in
      Lwt.catch notify (fun _ -> Lwt.return_unit)

  end)
|
|
|
|
(** Scheduler for outbound traffic: chunks are taken from a pipe and
    written to a file descriptor. *)
module WriteScheduler = Scheduler(struct

    let name = "io_scheduler(write)"

    (* Pipe holding the chunks waiting to be written. *)
    type in_param = MBytes.t Lwt_pipe.t

    (* Take the next chunk from the pipe; every exception is reported
       as [Exn Lwt_pipe.Closed]. *)
    let pop p =
      let take () = Lwt_pipe.pop p >>= fun chunk -> return chunk in
      Lwt.catch take (fun _ -> fail (Exn Lwt_pipe.Closed))

    (* File descriptor the chunks are written to. *)
    type out_param = Lwt_unix.file_descr

    (* Write a chunk. [ECONNRESET], [EPIPE], [Lwt.Canceled] and
       [End_of_file] are reported as [Connection_closed]; any other
       exception is wrapped with [error_exn]. *)
    let push fd buf =
      let send () =
        Lwt_utils.write_mbytes fd buf >>= fun () -> return () in
      let on_exn = function
        | Unix.Unix_error (Unix.ECONNRESET, _, _)
        | Unix.Unix_error (Unix.EPIPE, _, _)
        | Lwt.Canceled
        | End_of_file ->
            fail Connection_closed
        | exn ->
            Lwt.return (error_exn exn) in
      Lwt.catch send on_exn

    (* Nothing to do on this side when the connection is dropped. *)
    let close _p _err = Lwt.return_unit

  end)
|
|
|
|
(** A registered connection, scheduled in both directions. *)
type connection = {
  (* Key of the connection in [sched.connected]. *)
  id: int ;
  (* Scheduler the connection was registered in. *)
  sched: t ;
  (* Underlying file descriptor. *)
  conn: Lwt_unix.file_descr ;
  (* Shared with both scheduler-level connections; its cancelation
     callback releases the resources of the connection. *)
  canceler: Canceler.t ;
  read_conn: ReadScheduler.connection ;
  (* Chunks (or closing errors) delivered by the read scheduler. *)
  read_queue: MBytes.t tzresult Lwt_pipe.t ;
  write_conn: WriteScheduler.connection ;
  (* Chunks waiting for the write scheduler. *)
  write_queue: MBytes.t Lwt_pipe.t ;
  (* Tail of a received chunk that did not fit in the last read. *)
  mutable partial_read: MBytes.t option ;
}

(** The global I/O scheduler. *)
and t = {
  (* Set by [shutdown]; [register] then refuses new connections. *)
  mutable closed: bool ;
  (* Registered connections, indexed by identifier. *)
  connected: connection Inttbl.t ;
  read_scheduler: ReadScheduler.t ;
  write_scheduler: WriteScheduler.t ;
  max_upload_speed: int option ; (* bytes per second. *)
  max_download_speed: int option ;
  (* Size of the buffer allocated for each read on a socket. *)
  read_buffer_size: int ;
  (* Optional bounds given to the pipes created by [register]. *)
  read_queue_size: int option ;
  write_queue_size: int option ;
}
|
|
|
|
(** [reset_quota st] splits the current average inflow and outflow
    evenly between the registered connections and refills their quotas
    accordingly (a positive leftover is discarded, a debt is carried
    over); it then refreshes the quota of both schedulers. *)
let reset_quota st =
  debug "--> reset quota" ;
  let average_of counter =
    (Moving_average.stat counter).Moving_average.average in
  let current_inflow = average_of st.read_scheduler.counter
  and current_outflow = average_of st.write_scheduler.counter in
  let nb_conn = Inttbl.length st.connected in
  if nb_conn > 0 then begin
    let fair_read_quota = current_inflow / nb_conn
    and fair_write_quota = current_outflow / nb_conn in
    let refill current fair = (min current 0) + fair in
    Inttbl.iter
      (fun _id conn ->
         let reader = conn.read_conn in
         reader.last_quota <- fair_read_quota ;
         reader.quota <- refill reader.quota fair_read_quota ;
         let writer = conn.write_conn in
         writer.last_quota <- fair_write_quota ;
         writer.quota <- refill writer.quota fair_write_quota)
      st.connected
  end ;
  ReadScheduler.update_quota st.read_scheduler ;
  WriteScheduler.update_quota st.write_scheduler
|
|
|
|
(** [create ?max_upload_speed ?max_download_speed ?read_queue_size
    ?write_queue_size ~read_buffer_size ()] builds a global scheduler
    with one read and one write scheduler, and registers [reset_quota]
    as a [Moving_average.on_update] callback. *)
let create
    ?max_upload_speed ?max_download_speed
    ?read_queue_size ?write_queue_size
    ~read_buffer_size
    () =
  log_info "--> create" ;
  let st =
    { closed = false ;
      connected = Inttbl.create 53 ;
      read_scheduler = ReadScheduler.create max_download_speed ;
      write_scheduler = WriteScheduler.create max_upload_speed ;
      max_upload_speed ;
      max_download_speed ;
      read_buffer_size ;
      read_queue_size ;
      write_queue_size ;
    } in
  let refresh () = reset_quota st in
  Moving_average.on_update refresh ;
  st
|
|
|
|
(** Raised by [register] when the scheduler is already marked closed. *)
exception Closed
|
|
|
|
(** Size accounted for one element of a read queue. *)
let read_size = function
  | Error _ ->
      (* we push Error only when we close the socket,
         we don't fear memory leaks in that case... *)
      0
  | Ok buf ->
      let word_bytes = Sys.word_size / 8 in
      word_bytes * 8 + MBytes.length buf + Lwt_pipe.push_overhead
|
|
|
|
(** Size accounted for one element of a write queue. *)
let write_size mbytes =
  let word_bytes = Sys.word_size / 8 in
  word_bytes * 6 + MBytes.length mbytes + Lwt_pipe.push_overhead
|
|
|
|
(** [register st fd] attaches the file descriptor [fd] to both
    schedulers of [st] and returns the resulting connection.
    Identifiers come from a counter private to this function.
    @raise Closed if [st] is already closed; [fd] is then closed in the
    background. *)
let register =
  let cpt = ref 0 in
  fun st conn ->
    if st.closed then begin
      Lwt.async (fun () -> Lwt_utils.safe_close conn) ;
      raise Closed
    end ;
    incr cpt ;
    let id = !cpt in
    let canceler = Canceler.create () in
    (* Optional bounds of the two pipes, with their sizing functions. *)
    let read_bound =
      map_option st.read_queue_size ~f:(fun limit -> limit, read_size) in
    let write_bound =
      map_option st.write_queue_size ~f:(fun limit -> limit, write_size) in
    let read_queue = Lwt_pipe.create ?size:read_bound () in
    let write_queue = Lwt_pipe.create ?size:write_bound () in
    let read_conn =
      ReadScheduler.create_connection
        st.read_scheduler (conn, st.read_buffer_size) read_queue canceler id
    and write_conn =
      WriteScheduler.create_connection
        st.write_scheduler write_queue conn canceler id in
    (* Cleanup performed when the connection is canceled. *)
    let release () =
      Inttbl.remove st.connected id ;
      Moving_average.destroy read_conn.counter ;
      Moving_average.destroy write_conn.counter ;
      Lwt_pipe.close write_queue ;
      Lwt_pipe.close read_queue ;
      Lwt_utils.safe_close conn in
    Canceler.on_cancel canceler release ;
    let registered = {
      sched = st ; id ; conn ; canceler ;
      read_queue ; read_conn ;
      write_queue ; write_conn ;
      partial_read = None ;
    } in
    Inttbl.add st.connected id registered ;
    log_info "--> register (%d)" registered.id ;
    registered
|
|
|
|
(** [write conn msg] queues [msg] for the write scheduler. Any exception
    raised by the pipe is reported as [Connection_closed]. *)
let write conn msg =
  let enqueue () =
    Lwt_pipe.push conn.write_queue msg >>= fun () -> return () in
  Lwt.catch enqueue (fun _ -> fail Connection_closed)
|
|
(** [write_now conn msg] is [Lwt_pipe.push_now] on the write queue of
    [conn]; its result is returned unchanged. *)
let write_now conn msg =
  Lwt_pipe.push_now conn.write_queue msg
|
|
|
|
(** [read_from conn ?pos ?len buf msg] copies at most [len] bytes of the
    chunk [msg] into [buf] at offset [pos] and returns the number of
    bytes copied. The part of the chunk that did not fit is stored in
    [conn.partial_read]. An [Error] message is reported as
    [Connection_closed]. [pos] defaults to [0] and [len] to the space
    left in [buf] after [pos]; both are checked with assertions. *)
let read_from conn ?pos ?len buf msg =
  let maxlen = MBytes.length buf in
  let pos = unopt ~default:0 pos in
  assert (0 <= pos && pos < maxlen) ;
  let len = unopt ~default:(maxlen - pos) len in
  assert (len <= maxlen - pos) ;
  match msg with
  | Error _ ->
      Error [Connection_closed]
  | Ok chunk ->
      let chunk_len = MBytes.length chunk in
      let copied = min len chunk_len in
      MBytes.blit chunk 0 buf pos copied ;
      let leftover = chunk_len - copied in
      if leftover > 0 then
        conn.partial_read <- Some (MBytes.sub chunk copied leftover) ;
      Ok copied
|
|
|
|
(** [read_now conn ?pos ?len buf] is the non-blocking read: it serves
    the pending partial chunk if there is one, otherwise whatever
    [Lwt_pipe.pop_now] returns for the read queue. A closed pipe is
    reported as [Some (Error [Connection_closed])]. *)
let read_now conn ?pos ?len buf =
  let consume msg = read_from conn ?pos ?len buf msg in
  match conn.partial_read with
  | None -> begin
      try
        map_option ~f:consume (Lwt_pipe.pop_now conn.read_queue)
      with Lwt_pipe.Closed -> Some (Error [Connection_closed])
    end
  | Some leftover ->
      conn.partial_read <- None ;
      Some (consume (Ok leftover))
|
|
|
|
(** [read conn ?pos ?len buf] serves the pending partial chunk if there
    is one, otherwise waits for the next element of the read queue.
    Any exception raised while waiting or copying is reported as
    [Connection_closed]. *)
let read conn ?pos ?len buf =
  match conn.partial_read with
  | None ->
      Lwt.catch
        (fun () ->
           Lwt_pipe.pop conn.read_queue >>= fun msg ->
           Lwt.return (read_from conn ?pos ?len buf msg))
        (fun _ -> fail Connection_closed)
  | Some leftover ->
      conn.partial_read <- None ;
      Lwt.return (read_from conn ?pos ?len buf (Ok leftover))
|
|
|
|
(** [read_full conn ?pos ?len buf] calls [read] repeatedly until exactly
    [len] bytes have been stored in [buf] starting at [pos], stopping
    at the first error. Defaults and assertions are those of
    [read_from]. *)
let read_full conn ?pos ?len buf =
  let maxlen = MBytes.length buf in
  let pos = unopt ~default:0 pos in
  let len = unopt ~default:(maxlen - pos) len in
  assert (0 <= pos && pos < maxlen) ;
  assert (len <= maxlen - pos) ;
  let rec fill offset remaining =
    match remaining with
    | 0 -> return ()
    | _ ->
        read conn ~pos:offset ~len:remaining buf >>=? fun received ->
        fill (offset + received) (remaining - received) in
  fill pos len
|
|
|
|
(** [convert ~ws ~rs] merges the statistics of a write counter ([ws])
    and of a read counter ([rs]) into one [Stat] record. *)
let convert ~ws ~rs =
  let sent = ws.Moving_average.total
  and outflow = ws.Moving_average.average in
  let recv = rs.Moving_average.total
  and inflow = rs.Moving_average.average in
  { Stat.total_sent = sent ;
    total_recv = recv ;
    current_outflow = outflow ;
    current_inflow = inflow ;
  }
|
|
|
|
(** Statistics of the two schedulers, all connections included. *)
let global_stat st =
  let rs = Moving_average.stat st.read_scheduler.counter in
  let ws = Moving_average.stat st.write_scheduler.counter in
  convert ~rs ~ws
|
|
|
|
(** Statistics of a single connection. *)
let stat conn =
  let rs = Moving_average.stat conn.read_conn.counter in
  let ws = Moving_average.stat conn.write_conn.counter in
  convert ~rs ~ws
|
|
|
|
(** [close ?timeout conn] unregisters [conn], closes its write queue and
    returns the result of the last push started by the write
    scheduler for this connection. *)
let close ?timeout conn =
  lwt_log_info "--> close (%d)" conn.id >>= fun () ->
  Inttbl.remove conn.sched.connected conn.id ;
  Lwt_pipe.close conn.write_queue ;
  (* NOTE(review): the cancelation promise is wrapped with [return] and
     then discarded by the [>>=?] below, so it looks like this step does
     not actually wait for the canceler -- confirm the intent. *)
  let cancelation =
    match timeout with
    | Some timeout ->
        Lwt_utils.with_timeout
          ~canceler:conn.canceler timeout
          (fun canceler -> return (Canceler.cancelation canceler))
    | None ->
        return (Canceler.cancelation conn.canceler) in
  cancelation >>=? fun _ ->
  conn.write_conn.current_push >>= fun res ->
  lwt_log_info "<-- close (%d)" conn.id >>= fun () ->
  Lwt.return res
|
|
|
|
(** [iter_connection st f] applies [f id conn] to every registered
    connection of [st]. *)
let iter_connection st f =
  Inttbl.iter f st.connected
|
|
|
|
(** [shutdown ?timeout st] marks [st] as closed (so that [register]
    refuses new connections), stops the read scheduler, closes every
    registered connection (forwarding [timeout] to [close] and ignoring
    the individual results), and finally stops the write scheduler. *)
let shutdown ?timeout st =
  lwt_log_info "--> shutdown" >>= fun () ->
  st.closed <- true ;
  ReadScheduler.shutdown st.read_scheduler >>= fun () ->
  (* [close] removes the connection from [st.connected]. Modifying a
     hash table while iterating over it is unspecified, so the
     connections are first copied into a list. *)
  let conns =
    Inttbl.fold (fun _peer_id conn acc -> conn :: acc) st.connected [] in
  (* All the closings are started at once, then awaited together. *)
  Lwt.join
    (List.map
       (fun conn -> close ?timeout conn >>= fun _ -> Lwt.return_unit)
       conns) >>= fun () ->
  WriteScheduler.shutdown st.write_scheduler >>= fun () ->
  lwt_log_info "<-- shutdown" >>= fun () ->
  Lwt.return_unit
|