Shell/P2p: Use `Lwt_utils.{read/write}

This commit is contained in:
Grégoire Henry 2016-11-15 02:04:36 +01:00
parent 6afcc1ecdd
commit cfba0d9cb7

View File

@ -196,19 +196,8 @@ module Make (P: PARAMS) = struct
module BE = EndianBigstring.BigEndian module BE = EndianBigstring.BigEndian
(** Read a message from a file descriptor and returns (tag, msg) *) (* read a message from a TCP socket *)
let read fd buf = let recv_msg fd buf =
let rec read_into_exactly ?(pos=0) ?len descr buf =
let len = match len with None -> MBytes.length buf | Some l -> l in
let rec inner pos len =
if len = 0 then
Lwt.return_unit
else
Lwt_bytes.read descr buf pos len >>= fun nb_read ->
inner (pos + nb_read) (len - nb_read)
in
inner pos len
in
catch (fun () -> catch (fun () ->
Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen -> Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen ->
if hdrlen <> 4 then begin if hdrlen <> 4 then begin
@ -223,42 +212,26 @@ module Make (P: PARAMS) = struct
return None return None
end end
else else
read_into_exactly fd buf ~pos:4 ~len >|= fun () -> Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () ->
let tag = BE.get_uint8 buf 4 in let tag = BE.get_uint8 buf 4 in
Some (tag, MBytes.sub buf 4 len)) let msg = MBytes.sub buf 4 len in
match max_length tag with
| Some len when MBytes.length msg > len -> None
| _ -> Data_encoding.Binary.of_bytes msg_encoding msg)
(function (function
| Unix.Unix_error (_err, _, _) -> return None | Unix.Unix_error (_err, _, _) -> return None
| e -> fail e) | e -> fail e)
(** Write a message to file descriptor. *)
let write ?(pos=0) ?len descr buf =
let len = match len with None -> MBytes.length buf | Some l -> l in
catch
(fun () ->
Lwt_bytes.write descr buf pos len >>= fun _nb_written ->
return true)
(function
| Unix.Unix_error _ -> return false
| e -> fail e)
(* read a message from a TCP socket *)
let recv_msg fd buf =
read fd buf >|= function
| None -> None
| Some (tag, msg) ->
match max_length tag with
| Some len when MBytes.length msg > len -> None
| _ -> Data_encoding.Binary.of_bytes msg_encoding msg
(* send a message over a TCP socket *) (* send a message over a TCP socket *)
let send_msg fd buf packet = let send_msg fd buf packet =
catch catch
(fun () -> (fun () ->
match Data_encoding.Binary.write msg_encoding packet buf 4 with match Data_encoding.Binary.write msg_encoding packet buf 4 with
| None -> return false | None -> return_false
| Some len -> | Some len ->
BE.set_int32 buf 0 @@ Int32.of_int (len - 4); BE.set_int32 buf 0 @@ Int32.of_int (len - 4);
write fd buf ~len Lwt_utils.write_mbytes fd buf ~len >>= fun () ->
return_true
) )
(fun exn -> Lwt.fail exn) (fun exn -> Lwt.fail exn)