Shell/P2p: do not read the tag twice...

This commit is contained in:
Grégoire Henry 2016-11-15 02:09:17 +01:00
parent 9d67c1fea5
commit 6121c518d4

View File

@ -129,103 +129,94 @@ module Make (P: PARAMS) = struct
transmission (and needs not being replied). The [Unkown] packet is transmission (and needs not being replied). The [Unkown] packet is
not a real kind of packet, it means that something indecypherable not a real kind of packet, it means that something indecypherable
was transmitted. *) was transmitted. *)
type hello = {
gid: gid;
port: int option;
versions: version list;
}
let hello_encoding =
let open Data_encoding in
conv
(fun { gid; port; versions } -> (gid, port, versions))
(fun (gid, port, versions) -> { gid; port; versions })
(obj3
(req "gid" (Fixed.string 16)) (* TODO: get rid of constant *)
(opt "port" int16)
(req "versions" (Variable.list version_encoding)))
type msg = type msg =
| Connect of hello | Connect of {
gid : string ;
port : int option ;
versions : version list ;
}
| Disconnect | Disconnect
| Advertise of point list
| Bootstrap | Bootstrap
| Advertise of point list
| Message of P.msg | Message of P.msg
let msg_encoding = let msg_encoding =
let open Data_encoding in let open Data_encoding in
union ~tag_size:`Uint8 begin [ union ~tag_size:`Uint16
case ~tag:0x00 hello_encoding ([ case ~tag:0x00
(function Connect hello -> Some hello | _ -> None) (obj3
(fun hello -> Connect hello); (req "gid" (Fixed.string gid_length))
(req "port" uint16)
(req "versions" (Variable.list version_encoding)))
(function
| Connect { gid ; port ; versions } ->
let port = match port with None -> 0 | Some port -> port in
Some (gid, port, versions)
| _ -> None)
(fun (gid, port, versions) ->
let port = if port = 0 then None else Some port in
Connect { gid ; port ; versions });
case ~tag:0x01 null case ~tag:0x01 null
(function Disconnect -> Some () | _ -> None) (function Disconnect -> Some () | _ -> None)
(fun () -> Disconnect); (fun () -> Disconnect);
case ~tag:0x04 (Variable.list point_encoding) case ~tag:0x02 null
(function Advertise points -> Some points | _ -> None)
(fun points -> Advertise points);
case ~tag:0x05 null
(function Bootstrap -> Some () | _ -> None) (function Bootstrap -> Some () | _ -> None)
(fun () -> Bootstrap); (fun () -> Bootstrap);
case ~tag:0x03 (Variable.list point_encoding)
(function Advertise points -> Some points | _ -> None)
(fun points -> Advertise points);
] @ ] @
ListLabels.map P.encodings ~f:begin function Encoding { tag; encoding; wrap; unwrap } -> ListLabels.map P.encodings
~f:(function Encoding { tag ; encoding ; wrap ; unwrap } ->
case ~tag encoding case ~tag encoding
(function Message msg -> unwrap msg | _ -> None) (function Message msg -> unwrap msg | _ -> None)
(fun msg -> Message (wrap msg)) (fun msg -> Message (wrap msg))))
end
end
let max_length = function let hdrlen = 2
| 0 -> Some 1024 let maxlen = hdrlen + 2 lsl 16
| 1 -> Some 0
| 2 -> Some 0
| 3 -> Some 0
| 4 -> Some (1 + 1000 * 17) (* tag + 1000 * max (point size) *)
| 5 -> Some 0
| n -> ListLabels.fold_left P.encodings ~init:None ~f:begin fun a -> function
Encoding { tag; max_length } -> if tag = n then max_length else a
end
module BE = EndianBigstring.BigEndian
(* read a message from a TCP socket *) (* read a message from a TCP socket *)
let recv_msg fd buf = let recv_msg fd buf =
catch (fun () -> catch
Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen -> (fun () ->
if hdrlen <> 4 then begin assert (MBytes.length buf >= 2 lsl 16) ;
debug "read: could not read enough bytes to determine message size, aborting"; Lwt_utils.read_mbytes ~len:hdrlen fd buf >>= fun () ->
return None let len = EndianBigstring.BigEndian.get_uint16 buf 0 in
end (* TODO timeout read ??? *)
else Lwt_utils.read_mbytes ~len fd buf >>= fun () ->
Lwt_bytes.read fd buf 0 4 >>= fun _hdrlen -> (* TODO conditionnaly decrypt payload... ?? *)
let len = Int32.to_int (BE.get_int32 buf 0) in match Data_encoding.Binary.read msg_encoding buf 0 len with
if len < 0 || len > MBytes.length buf then begin | None ->
debug "read: invalid message size %d" len; (* TODO track invalid message *)
return None return Disconnect
end | Some (read, _) when read <> len ->
else (* TODO track invalid message *)
Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () -> return Disconnect
let tag = BE.get_uint8 buf 4 in | Some (_, msg) ->
let msg = MBytes.sub buf 4 len in Lwt.return msg)
match max_length tag with
| Some len when MBytes.length msg > len -> None
| _ -> Data_encoding.Binary.of_bytes msg_encoding msg)
(function (function
| Unix.Unix_error (_err, _, _) -> return None | Unix.Unix_error _ -> return Disconnect
| e -> fail e) | e -> fail e)
(* send a message over a TCP socket *) (* send a message over a TCP socket *)
let send_msg fd buf packet = let send_msg fd buf msg =
catch catch
(fun () -> (fun () ->
match Data_encoding.Binary.write msg_encoding packet buf 4 with match Data_encoding.Binary.write msg_encoding msg buf hdrlen with
| None -> return_false | None -> return_false
| Some len -> | Some len ->
BE.set_int32 buf 0 @@ Int32.of_int (len - 4); if len > maxlen then
Lwt_utils.write_mbytes fd buf ~len >>= fun () -> return_false
return_true else begin
) EndianBigstring.BigEndian.set_int16 buf 0 (len - hdrlen) ;
(fun exn -> Lwt.fail exn) (* TODO conditionnaly encrypt payload... ? *)
(* TODO timeout write ??? *)
Lwt_utils.write_mbytes ~len fd buf >>= fun () ->
return true
end)
(function
| Unix.Unix_error _ -> return_false
| e -> fail e)
(* A peer handle, as a record-encoded object, abstract from the (* A peer handle, as a record-encoded object, abstract from the
outside world. A hidden Lwt worker is associated to a peer at its outside world. A hidden Lwt worker is associated to a peer at its
@ -377,17 +368,18 @@ module Make (P: PARAMS) = struct
let cancelation, cancel, on_cancel = canceler () in let cancelation, cancel, on_cancel = canceler () in
(* a cancelable reception *) (* a cancelable reception *)
let recv buf = let recv buf =
pick [ (recv_msg socket buf >|= function Some p -> p | None -> Disconnect); pick [ recv_msg socket buf ;
(cancelation () >>= fun () -> return Disconnect) ] in (cancelation () >>= fun () -> return Disconnect) ] in
(* First step: send and receive credentials, makes no difference (* First step: send and receive credentials, makes no difference
whether we're trying to connect to a peer or checking an incoming whether we're trying to connect to a peer or checking an incoming
connection, both parties must first present themselves. *) connection, both parties must first present themselves. *)
let rec connect buf = let rec connect buf =
send_msg socket buf (Connect { gid = my_gid ; send_msg socket buf
(Connect { gid = my_gid ;
port = config.incoming_port ; port = config.incoming_port ;
versions = P.supported_versions }) >>= fun _ -> versions = P.supported_versions }) >>= fun _ ->
pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ; pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ;
recv buf ] >>= function recv_msg socket buf ] >>= function
| Connect { gid; port = listening_port; versions } -> | Connect { gid; port = listening_port; versions } ->
debug "(%a) connection requested from %a @ %a:%d" debug "(%a) connection requested from %a @ %a:%d"
pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ; pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ;
@ -458,7 +450,7 @@ module Make (P: PARAMS) = struct
(* Launch the worker *) (* Launch the worker *)
receiver () receiver ()
in in
let buf = MBytes.create 0x100_000 in let buf = MBytes.create maxlen in
on_cancel (fun () -> on_cancel (fun () ->
send_msg socket buf Disconnect >>= fun _ -> send_msg socket buf Disconnect >>= fun _ ->
LU.close socket >>= fun _ -> LU.close socket >>= fun _ ->
@ -542,8 +534,7 @@ module Make (P: PARAMS) = struct
(* A good random string so it is probably unique on the network *) (* A good random string so it is probably unique on the network *)
let fresh_gid () = let fresh_gid () =
Bytes.to_string @@ Sodium.Random.Bytes.generate 16 Bytes.to_string @@ Sodium.Random.Bytes.generate gid_length
(* The (fixed size) broadcast frame. *) (* The (fixed size) broadcast frame. *)
let discovery_message_encoding = let discovery_message_encoding =