(**************************************************************************)
(*                                                                        *)
(*    Copyright (c) 2014 - 2016.                                          *)
(*    Dynamic Ledger Solutions, Inc.                                      *)
(*                                                                        *)
(*    All rights reserved. No warranty, explicit or implicit, provided.   *)
(*                                                                        *)
(**************************************************************************)

(* TODO decide whether we need to preallocate buffers or not. *)

open P2p_types
include Logging.Make (struct let name = "p2p.io-scheduler" end)

(** Hash tables keyed by integer connection identifiers. *)
module Inttbl = Hashtbl.Make(struct
    type t = int
    let equal (x: int) (y: int) = x = y
    let hash = Hashtbl.hash
  end)

(** Value passed as [~alpha] to every [Moving_average.create] call of
    this file. *)
let alpha = 0.2

(** One direction of data transfer handled by a {!Scheduler}: chunks are
    obtained with [pop] from an [in_param] and forwarded with [push] to an
    [out_param]. [close] is called on the [out_param], with the error
    list that triggered it, when the scheduler cancels a connection. *)
module type IO = sig
  val name: string
  type in_param
  val pop: in_param -> MBytes.t tzresult Lwt.t
  type out_param
  val push: out_param -> MBytes.t -> unit tzresult Lwt.t
  val close: out_param -> error list -> unit Lwt.t
end

type error += Connection_closed

(** Generic scheduler: a single worker moves chunks from the [pop] side
    to the [push] side of every registered connection, serving the
    high-priority queue before the low-priority one and pausing while
    the global quota is exhausted. *)
module Scheduler(IO : IO) = struct

  type t = {
    canceler: Canceler.t ;
    mutable worker: unit Lwt.t ;
    counter: Moving_average.t ;
    (* [None] disables the global quota check, see [check_quota]. *)
    max_speed: int option ;
    (* Remaining global quota; decreased by the length of each handled
       chunk and refilled by [update_quota]. *)
    mutable quota: int ;
    quota_updated: unit Lwt_condition.t ;
    (* Broadcast by [waiter] when a result is queued while both queues
       were empty. *)
    readys: unit Lwt_condition.t ;
    readys_high: (connection * MBytes.t tzresult) Queue.t ;
    readys_low: (connection * MBytes.t tzresult) Queue.t ;
  }

  and connection = {
    id: int ;
    mutable closed: bool ;
    canceler: Canceler.t ;
    in_param: IO.in_param ;
    out_param: IO.out_param ;
    mutable current_pop: MBytes.t tzresult Lwt.t ;
    mutable current_push: unit tzresult Lwt.t ;
    counter: Moving_average.t ;
    (* Per-connection quota; a strictly positive value routes the
       results of this connection to the high-priority queue. *)
    mutable quota: int ;
    mutable last_quota: int ;
  }

  (** [cancel conn err] does nothing when [conn] is already flagged as
      closed. Otherwise it flags it, calls [IO.close] with [err]
      (ignoring any exception raised by it) and cancels the canceler of
      the connection. *)
  let cancel (conn : connection) err =
    Lwt_utils.unless conn.closed begin fun () ->
      lwt_debug "Connection closed (%d, %s) " conn.id IO.name >>= fun () ->
      conn.closed <- true ;
      Lwt.catch
        (fun () -> IO.close conn.out_param err)
        (fun _ -> Lwt.return_unit) >>= fun () ->
      Canceler.cancel conn.canceler
    end

  (** [waiter st conn] starts a new [IO.pop] for [conn] and, in the
      background, queues its result in [st] once both the pop and the
      push currently recorded for [conn] have terminated. The previous
      pop must not be pending anymore. *)
  let waiter st conn =
    assert (Lwt.state conn.current_pop <> Sleep) ;
    conn.current_pop <- IO.pop conn.in_param ;
    Lwt.async begin fun () ->
      conn.current_pop >>= fun res ->
      (* Wait for the previous push, ignoring its result, before
         queueing the next chunk of the same connection. *)
      conn.current_push >>= fun _ ->
      let was_empty =
        Queue.is_empty st.readys_high && Queue.is_empty st.readys_low in
      if conn.quota > 0 then
        Queue.push (conn, res) st.readys_high
      else
        Queue.push (conn, res) st.readys_low ;
      (* Only signal on the empty to non-empty transition. *)
      if was_empty then Lwt_condition.broadcast st.readys () ;
      Lwt.return_unit
    end

  (** [wait_data st] returns immediately when one of the two queues
      holds a result, and otherwise waits on the [readys] condition. *)
  let wait_data st =
    let is_empty =
      Queue.is_empty st.readys_high && Queue.is_empty st.readys_low in
    if is_empty then Lwt_condition.wait st.readys else Lwt.return_unit

  (** [check_quota st] waits on [quota_updated] when a maximum speed is
      configured and the global quota is negative; otherwise it only
      yields. *)
  let check_quota st =
    if st.max_speed <> None && st.quota < 0 then begin
      lwt_debug "scheduler.wait_quota(%s)" IO.name >>= fun () ->
      Lwt_condition.wait st.quota_updated
    end else
      Lwt_unix.yield ()

  (** Main loop of the scheduler. Each iteration waits for quota and for
      a queued result, takes it from the high-priority queue if that one
      is not empty (from the low-priority one otherwise) and then:
      - on an error, cancels the connection, except for
        [Lwt_utils.Canceled] which is silently skipped;
      - on a chunk, starts pushing it, updates counters and quotas and
        re-arms the connection with {!waiter}.
      The loop stops once the canceler of [st] has been canceled. *)
  let rec worker_loop st =
    check_quota st >>= fun () ->
    lwt_debug "scheduler.wait(%s)" IO.name >>= fun () ->
    Lwt.pick [
      Canceler.cancelation st.canceler ;
      wait_data st
    ] >>= fun () ->
    if Canceler.canceled st.canceler then
      Lwt.return_unit
    else
      let prio, (conn, msg) =
        if not (Queue.is_empty st.readys_high) then
          true, (Queue.pop st.readys_high)
        else
          false, (Queue.pop st.readys_low)
      in
      match msg with
      | Error [Lwt_utils.Canceled] ->
          worker_loop st
      | Error ([Connection_closed |
                Exn ( Lwt_pipe.Closed |
                      Unix.Unix_error ((EBADF | ETIMEDOUT), _, _) )]
               as err) ->
          lwt_debug "Connection closed (pop: %d, %s)"
            conn.id IO.name >>= fun () ->
          cancel conn err >>= fun () ->
          worker_loop st
      | Error err ->
          lwt_log_error
            "@[Unexpected error in connection (pop: %d, %s):@ %a@]"
            conn.id IO.name pp_print_error err >>= fun () ->
          cancel conn err >>= fun () ->
          worker_loop st
      | Ok msg ->
          (* The push is started but not awaited here: it is stored in
             [current_push], which [waiter] waits for before queueing
             the next result of this connection. *)
          conn.current_push <- begin
            IO.push conn.out_param msg >>= function
            | Ok ()
            | Error [Lwt_utils.Canceled] ->
                return ()
            | Error ([Connection_closed |
                      Exn (Unix.Unix_error (EBADF, _, _) |
                           Lwt_pipe.Closed)] as err) ->
                lwt_debug "Connection closed (push: %d, %s)"
                  conn.id IO.name >>= fun () ->
                cancel conn err >>= fun () ->
                return ()
            | Error err ->
                lwt_log_error
                  "@[Unexpected error in connection (push: %d, %s):@ %a@]"
                  conn.id IO.name pp_print_error err >>= fun () ->
                cancel conn err >>= fun () ->
                Lwt.return (Error err)
          end ;
          let len = MBytes.length msg in
          lwt_debug "Handle: %d (%d, %s)" len conn.id IO.name >>= fun () ->
          Moving_average.add st.counter len ;
          st.quota <- st.quota - len ;
          Moving_average.add conn.counter len ;
          (* The per-connection quota is only charged for chunks that
             were served from the high-priority queue. *)
          if prio then conn.quota <- conn.quota - len ;
          waiter st conn ;
          worker_loop st

  (** [create max_speed] builds a scheduler and starts its worker. The
      initial global quota is [max_speed], or [0] when it is [None]. *)
  let create max_speed =
    let st = {
      canceler = Canceler.create () ;
      worker = Lwt.return_unit ;
      counter = Moving_average.create ~init:0 ~alpha ;
      max_speed ; quota = unopt ~default:0 max_speed ;
      quota_updated = Lwt_condition.create () ;
      readys = Lwt_condition.create () ;
      readys_high = Queue.create () ;
      readys_low = Queue.create () ;
    } in
    st.worker <-
      Lwt_utils.worker IO.name
        ~run:(fun () -> worker_loop st)
        ~cancel:(fun () -> Canceler.cancel st.canceler) ;
    st

  (** [create_connection st in_param out_param canceler id] builds a
      connection with zero quota and immediately starts popping from
      [in_param] through {!waiter}. *)
  let create_connection st in_param out_param canceler id =
    debug "scheduler(%s).create_connection (%d)" IO.name id ;
    let conn =
      { id ; closed = false ;
        canceler ;
        in_param ; out_param ;
        (* An already terminated thread, so that the assertion of
           [waiter] holds on the first call. *)
        current_pop = Lwt.fail Not_found (* dummy *) ;
        current_push = return () ;
        counter = Moving_average.create ~init:0 ~alpha ;
        quota = 0 ; last_quota = 0 ;
      } in
    waiter st conn ;
    conn

  (** [update_quota st] refills the global quota when a maximum speed is
      set (a negative remainder is carried over, a positive one is
      dropped) and wakes up the worker. It then moves to the
      high-priority queue every low-priority result whose connection has
      a strictly positive quota, keeping the relative order of the
      others. *)
  let update_quota st =
    debug "scheduler(%s).update_quota" IO.name ;
    iter_option st.max_speed ~f:begin fun quota ->
      st.quota <- (min st.quota 0) + quota ;
      Lwt_condition.broadcast st.quota_updated ()
    end ;
    if not (Queue.is_empty st.readys_low) then begin
      let tmp = Queue.create () in
      Queue.iter
        (fun ((conn : connection), _ as msg) ->
           if conn.quota > 0 then
             Queue.push msg st.readys_high
           else
             Queue.push msg tmp)
        st.readys_low ;
      Queue.clear st.readys_low ;
      Queue.transfer tmp st.readys_low ;
    end

  (** [shutdown st] cancels the canceler of [st] and waits for the
      worker to terminate. *)
  let shutdown st =
    lwt_debug "--> scheduler(%s).shutdown" IO.name >>= fun () ->
    Canceler.cancel st.canceler >>= fun () ->
    st.worker >>= fun () ->
    lwt_debug "<-- scheduler(%s).shutdown" IO.name >>= fun () ->
    Lwt.return_unit

end

(** Scheduler for incoming data: chunks are read from a file descriptor
    and pushed, wrapped in [Ok], into a pipe. *)
module ReadScheduler = Scheduler(struct
    let name = "io_scheduler(read)"
    (* File descriptor and maximal size of a single read. *)
    type in_param = Lwt_unix.file_descr * int
    (* A read of zero bytes and an ECONNRESET error are both reported as
       [Connection_closed]; other exceptions are wrapped by
       [error_exn]. *)
    let pop (fd, maxlen) =
      Lwt.catch
        (fun () ->
           let buf = MBytes.create maxlen in
           Lwt_bytes.read fd buf 0 maxlen >>= fun len ->
           if len = 0 then
             fail Connection_closed
           else
             return (MBytes.sub buf 0 len) )
        (function
          | Unix.Unix_error(Unix.ECONNRESET, _, _) ->
              fail Connection_closed
          | exn ->
              Lwt.return (error_exn exn))
    type out_param = MBytes.t tzresult Lwt_pipe.t
    let push p msg =
      Lwt.catch
        (fun () -> Lwt_pipe.push p (Ok msg) >>= return)
        (fun exn -> fail (Exn exn))
    (* Forward the error to the consumer of the pipe; a failure to do so
       is ignored. *)
    let close p err =
      Lwt.catch
        (fun () -> Lwt_pipe.push p (Error err))
        (fun _ -> Lwt.return_unit)
  end)

(** Scheduler for outgoing data: chunks are popped from a pipe and
    written to a file descriptor. *)
module WriteScheduler = Scheduler(struct
    let name = "io_scheduler(write)"
    type in_param = MBytes.t Lwt_pipe.t
    (* NOTE(review): every exception raised while popping is reported as
       [Exn Lwt_pipe.Closed], whatever its actual value. *)
    let pop p =
      Lwt.catch
        (fun () -> Lwt_pipe.pop p >>= return)
        (fun _ -> fail (Exn Lwt_pipe.Closed))
    type out_param = Lwt_unix.file_descr
    let push fd buf =
      Lwt.catch
        (fun () -> Lwt_utils.write_mbytes fd buf >>= return)
        (function
          | Unix.Unix_error(Unix.ECONNRESET, _, _)
          | Unix.Unix_error(Unix.EPIPE, _, _)
          | Lwt.Canceled
          | End_of_file ->
              fail Connection_closed
          | exn ->
              Lwt.return (error_exn exn))
    let close _p _err = Lwt.return_unit
  end)

type connection = {
  id: int ;
  sched: t ;
  conn: Lwt_unix.file_descr ;
  canceler: Canceler.t ;
  read_conn: ReadScheduler.connection ;
  read_queue: MBytes.t tzresult Lwt_pipe.t ;
  write_conn: WriteScheduler.connection ;
  write_queue: MBytes.t Lwt_pipe.t ;
  (* Tail of a chunk that did not fit in the buffer of the previous
     read, see [read_from]. *)
  mutable partial_read: MBytes.t option ;
}

and t = {
  mutable closed: bool ;
  connected: connection Inttbl.t ;
  read_scheduler: ReadScheduler.t ;
  write_scheduler: WriteScheduler.t ;
  max_upload_speed: int option ; (* bytes per second. *)
  max_download_speed: int option ;
  read_buffer_size: int ;
  read_queue_size: int option ;
  write_queue_size: int option ;
}

(** [reset_quota st] splits the current average inflow and outflow
    evenly between the registered connections, adds that share to each
    per-connection quota (a negative remainder is carried over, a
    positive one is dropped) and then updates both schedulers. It is
    registered with [Moving_average.on_update] by {!create}. *)
let reset_quota st =
  debug "--> reset quota" ;
  let { Moving_average.average = current_inflow } =
    Moving_average.stat st.read_scheduler.counter
  and { Moving_average.average = current_outflow } =
    Moving_average.stat st.write_scheduler.counter in
  let nb_conn = Inttbl.length st.connected in
  if nb_conn > 0 then begin
    let fair_read_quota = current_inflow / nb_conn
    and fair_write_quota = current_outflow / nb_conn in
    Inttbl.iter
      (fun _id conn ->
         conn.read_conn.last_quota <- fair_read_quota ;
         conn.read_conn.quota <-
           (min conn.read_conn.quota 0) + fair_read_quota ;
         conn.write_conn.last_quota <- fair_write_quota ;
         conn.write_conn.quota <-
           (min conn.write_conn.quota 0) + fair_write_quota ; )
      st.connected
  end ;
  ReadScheduler.update_quota st.read_scheduler ;
  WriteScheduler.update_quota st.write_scheduler

(** [create ?max_upload_speed ?max_download_speed ?read_queue_size
    ?write_queue_size ~read_buffer_size ()] builds an I/O scheduler with
    one read and one write scheduler, and registers {!reset_quota} as a
    [Moving_average.on_update] callback. *)
let create
    ?max_upload_speed ?max_download_speed
    ?read_queue_size ?write_queue_size
    ~read_buffer_size
    () =
  log_info "--> create" ;
  let st = {
    closed = false ;
    connected = Inttbl.create 53 ;
    read_scheduler = ReadScheduler.create max_download_speed ;
    write_scheduler = WriteScheduler.create max_upload_speed ;
    max_upload_speed ;
    max_download_speed ;
    read_buffer_size ;
    read_queue_size ;
    write_queue_size ;
  } in
  Moving_average.on_update (fun () -> reset_quota st) ;
  st

(** Raised by {!register} when the scheduler has been shut down. *)
exception Closed

(** Size attributed to an element of a read queue; handed to
    [Lwt_pipe.create] together with the queue bound in {!register}. *)
let read_size = function
  | Ok buf ->
      (Sys.word_size / 8) * 8 + MBytes.length buf + Lwt_pipe.push_overhead
  | Error _ -> 0 (* we push Error only when we close the socket,
                    we don't fear memory leaks in that case... *)

(** Size attributed to an element of a write queue; handed to
    [Lwt_pipe.create] together with the queue bound in {!register}. *)
let write_size mbytes =
  (Sys.word_size / 8) * 6 + MBytes.length mbytes + Lwt_pipe.push_overhead

(** [register st fd] attaches the file descriptor [fd] to both
    schedulers of [st] and returns the resulting connection. Identifiers
    come from a counter shared by all calls. Canceling the connection
    removes it from [st], destroys its counters, closes both queues and
    closes [fd].
    @raise Closed if [st] has been shut down; [fd] is then closed in the
    background. *)
let register =
  let cpt = ref 0 in
  fun st conn ->
    if st.closed then begin
      Lwt.async (fun () -> Lwt_utils.safe_close conn) ;
      raise Closed
    end else begin
      let id = incr cpt; !cpt in
      let canceler = Canceler.create () in
      let read_size =
        map_option st.read_queue_size ~f:(fun v -> v, read_size) in
      let write_size =
        map_option st.write_queue_size ~f:(fun v -> v, write_size) in
      let read_queue = Lwt_pipe.create ?size:read_size () in
      let write_queue = Lwt_pipe.create ?size:write_size () in
      let read_conn =
        ReadScheduler.create_connection
          st.read_scheduler (conn, st.read_buffer_size) read_queue
          canceler id
      and write_conn =
        WriteScheduler.create_connection
          st.write_scheduler write_queue conn canceler id in
      Canceler.on_cancel canceler begin fun () ->
        Inttbl.remove st.connected id ;
        Moving_average.destroy read_conn.counter ;
        Moving_average.destroy write_conn.counter ;
        Lwt_pipe.close write_queue ;
        Lwt_pipe.close read_queue ;
        Lwt_utils.safe_close conn
      end ;
      let conn = {
        sched = st ; id ; conn ; canceler ;
        read_queue ; read_conn ;
        write_queue ; write_conn ;
        partial_read = None ;
      } in
      Inttbl.add st.connected id conn ;
      log_info "--> register (%d)" conn.id ;
      conn
    end

(** [write conn msg] pushes [msg] on the write queue of [conn]. Any
    exception raised by the push is reported as [Connection_closed]. *)
let write { write_queue } msg =
  Lwt.catch
    (fun () -> Lwt_pipe.push write_queue msg >>= return)
    (fun _ -> fail Connection_closed)

(** [write_now conn msg] is [Lwt_pipe.push_now] on the write queue of
    [conn]. *)
let write_now { write_queue } msg = Lwt_pipe.push_now write_queue msg

(** [read_from conn ?pos ?len buf msg] copies at most [len] bytes of the
    chunk [msg] into [buf] at offset [pos] and returns the number of
    bytes copied. Bytes that did not fit are stored in
    [conn.partial_read]. An [Error] chunk yields [Connection_closed].
    [pos] defaults to [0] and [len] to the space left in [buf] after
    [pos]. *)
let read_from conn ?pos ?len buf msg =
  let maxlen = MBytes.length buf in
  let pos = unopt ~default:0 pos in
  assert (0 <= pos && pos < maxlen) ;
  let len = unopt ~default:(maxlen - pos) len in
  assert (len <= maxlen - pos) ;
  match msg with
  | Ok msg ->
      let msg_len = MBytes.length msg in
      let read_len = min len msg_len in
      MBytes.blit msg 0 buf pos read_len ;
      if read_len < msg_len then
        conn.partial_read <-
          Some (MBytes.sub msg read_len (msg_len - read_len)) ;
      Ok read_len
  | Error _ ->
      Error [Connection_closed]

(** [read_now conn ?pos ?len buf] is the non-blocking variant of
    {!read}: pending partial data is served first, then whatever
    [Lwt_pipe.pop_now] returns. A closed pipe is reported as
    [Some (Error [Connection_closed])]. *)
let read_now conn ?pos ?len buf =
  match conn.partial_read with
  | Some msg ->
      conn.partial_read <- None ;
      Some (read_from conn ?pos ?len buf (Ok msg))
  | None ->
      try
        map_option
          ~f:(read_from conn ?pos ?len buf)
          (Lwt_pipe.pop_now conn.read_queue)
      with Lwt_pipe.Closed -> Some (Error [Connection_closed])

(** [read conn ?pos ?len buf] serves pending partial data if there is
    some, and otherwise waits for the next chunk of the read queue. It
    returns the number of bytes copied into [buf]. Any exception raised
    while popping is reported as [Connection_closed]. *)
let read conn ?pos ?len buf =
  match conn.partial_read with
  | Some msg ->
      conn.partial_read <- None ;
      Lwt.return (read_from conn ?pos ?len buf (Ok msg))
  | None ->
      Lwt.catch
        (fun () ->
           Lwt_pipe.pop conn.read_queue >|= fun msg ->
           read_from conn ?pos ?len buf msg)
        (fun _ -> fail Connection_closed)

(** [read_full conn ?pos ?len buf] calls {!read} repeatedly until
    exactly [len] bytes have been copied into [buf] starting at [pos],
    or until a read fails. *)
let read_full conn ?pos ?len buf =
  let maxlen = MBytes.length buf in
  let pos = unopt ~default:0 pos in
  let len = unopt ~default:(maxlen - pos) len in
  assert (0 <= pos && pos < maxlen) ;
  assert (len <= maxlen - pos) ;
  let rec loop pos len =
    if len = 0 then
      return ()
    else
      read conn ~pos ~len buf >>=? fun read_len ->
      loop (pos + read_len) (len - read_len) in
  loop pos len

(** [convert ~ws ~rs] builds a [Stat.t] from the statistics of a write
    counter ([ws]) and of a read counter ([rs]). *)
let convert ~ws ~rs =
  { Stat.total_sent = ws.Moving_average.total ;
    total_recv = rs.Moving_average.total ;
    current_outflow = ws.average ;
    current_inflow = rs.average ;
  }

(** Statistics of the counters of both schedulers. *)
let global_stat { read_scheduler ; write_scheduler } =
  let rs = Moving_average.stat read_scheduler.counter
  and ws = Moving_average.stat write_scheduler.counter in
  convert ~rs ~ws

(** Statistics of the counters of a single connection. *)
let stat { read_conn ; write_conn} =
  let rs = Moving_average.stat read_conn.counter
  and ws = Moving_average.stat write_conn.counter in
  convert ~rs ~ws

(** [close ?timeout conn] removes [conn] from its scheduler, closes its
    write queue and returns the result of the push currently recorded
    for its write side.
    NOTE(review): in both branches the cancelation thread is wrapped
    with [return] rather than awaited, so this function does not appear
    to wait for the cancelation itself -- confirm that this is
    intended. *)
let close ?timeout conn =
  lwt_log_info "--> close (%d)" conn.id >>= fun () ->
  Inttbl.remove conn.sched.connected conn.id ;
  Lwt_pipe.close conn.write_queue ;
  begin
    match timeout with
    | None ->
        return (Canceler.cancelation conn.canceler)
    | Some timeout ->
        Lwt_utils.with_timeout
          ~canceler:conn.canceler timeout begin fun canceler ->
          return (Canceler.cancelation canceler)
        end
  end >>=? fun _ ->
  conn.write_conn.current_push >>= fun res ->
  lwt_log_info "<-- close (%d)" conn.id >>= fun () ->
  Lwt.return res

(** [iter_connection st f] applies [f] to the identifier and the value
    of every registered connection. *)
let iter_connection { connected } f =
  Inttbl.iter f connected

(** [shutdown ?timeout st] flags [st] as closed, stops the read
    scheduler, closes every registered connection (ignoring the result
    of each {!close}) and finally stops the write scheduler. *)
let shutdown ?timeout st =
  lwt_log_info "--> shutdown" >>= fun () ->
  st.closed <- true ;
  ReadScheduler.shutdown st.read_scheduler >>= fun () ->
  (* [close] removes the connection from [st.connected]. The table is
     therefore snapshotted first: the behaviour of a Hashtbl traversal
     is unspecified when the table is modified while it runs. *)
  let conns =
    List.rev
      (Inttbl.fold (fun _peer_id conn acc -> conn :: acc) st.connected []) in
  (* Start every close right away, then wait for all of them. *)
  Lwt.join
    (List.map
       (fun conn -> close ?timeout conn >>= fun _ -> Lwt.return_unit)
       conns) >>= fun () ->
  WriteScheduler.shutdown st.write_scheduler >>= fun () ->
  lwt_log_info "<-- shutdown" >>= fun () ->
  Lwt.return_unit