Shell/P2p: do not read the tag twice...

This commit is contained in:
Grégoire Henry 2016-11-15 02:09:17 +01:00
parent 9d67c1fea5
commit 6121c518d4

View File

@ -129,103 +129,94 @@ module Make (P: PARAMS) = struct
transmission (and needs not being replied). The [Unkown] packet is
not a real kind of packet, it means that something indecypherable
was transmitted. *)
type hello = {
gid: gid;
port: int option;
versions: version list;
}
let hello_encoding =
let open Data_encoding in
conv
(fun { gid; port; versions } -> (gid, port, versions))
(fun (gid, port, versions) -> { gid; port; versions })
(obj3
(req "gid" (Fixed.string 16)) (* TODO: get rid of constant *)
(opt "port" int16)
(req "versions" (Variable.list version_encoding)))
type msg =
| Connect of hello
| Connect of {
gid : string ;
port : int option ;
versions : version list ;
}
| Disconnect
| Advertise of point list
| Bootstrap
| Advertise of point list
| Message of P.msg
let msg_encoding =
let open Data_encoding in
union ~tag_size:`Uint8 begin [
case ~tag:0x00 hello_encoding
(function Connect hello -> Some hello | _ -> None)
(fun hello -> Connect hello);
case ~tag:0x01 null
(function Disconnect -> Some () | _ -> None)
(fun () -> Disconnect);
case ~tag:0x04 (Variable.list point_encoding)
(function Advertise points -> Some points | _ -> None)
(fun points -> Advertise points);
case ~tag:0x05 null
(function Bootstrap -> Some () | _ -> None)
(fun () -> Bootstrap);
] @
ListLabels.map P.encodings ~f:begin function Encoding { tag; encoding; wrap; unwrap } ->
case ~tag encoding
(function Message msg -> unwrap msg | _ -> None)
(fun msg -> Message (wrap msg))
end
end
union ~tag_size:`Uint16
([ case ~tag:0x00
(obj3
(req "gid" (Fixed.string gid_length))
(req "port" uint16)
(req "versions" (Variable.list version_encoding)))
(function
| Connect { gid ; port ; versions } ->
let port = match port with None -> 0 | Some port -> port in
Some (gid, port, versions)
| _ -> None)
(fun (gid, port, versions) ->
let port = if port = 0 then None else Some port in
Connect { gid ; port ; versions });
case ~tag:0x01 null
(function Disconnect -> Some () | _ -> None)
(fun () -> Disconnect);
case ~tag:0x02 null
(function Bootstrap -> Some () | _ -> None)
(fun () -> Bootstrap);
case ~tag:0x03 (Variable.list point_encoding)
(function Advertise points -> Some points | _ -> None)
(fun points -> Advertise points);
] @
ListLabels.map P.encodings
~f:(function Encoding { tag ; encoding ; wrap ; unwrap } ->
case ~tag encoding
(function Message msg -> unwrap msg | _ -> None)
(fun msg -> Message (wrap msg))))
let max_length = function
| 0 -> Some 1024
| 1 -> Some 0
| 2 -> Some 0
| 3 -> Some 0
| 4 -> Some (1 + 1000 * 17) (* tag + 1000 * max (point size) *)
| 5 -> Some 0
| n -> ListLabels.fold_left P.encodings ~init:None ~f:begin fun a -> function
Encoding { tag; max_length } -> if tag = n then max_length else a
end
module BE = EndianBigstring.BigEndian
let hdrlen = 2
let maxlen = hdrlen + 2 lsl 16
(* read a message from a TCP socket *)
let recv_msg fd buf =
catch (fun () ->
Lwt_bytes.recv fd buf 0 4 [ Lwt_unix.MSG_PEEK ] >>= fun hdrlen ->
if hdrlen <> 4 then begin
debug "read: could not read enough bytes to determine message size, aborting";
return None
end
else
Lwt_bytes.read fd buf 0 4 >>= fun _hdrlen ->
let len = Int32.to_int (BE.get_int32 buf 0) in
if len < 0 || len > MBytes.length buf then begin
debug "read: invalid message size %d" len;
return None
end
else
Lwt_utils.read_mbytes fd buf ~pos:4 ~len >|= fun () ->
let tag = BE.get_uint8 buf 4 in
let msg = MBytes.sub buf 4 len in
match max_length tag with
| Some len when MBytes.length msg > len -> None
| _ -> Data_encoding.Binary.of_bytes msg_encoding msg)
catch
(fun () ->
assert (MBytes.length buf >= 2 lsl 16) ;
Lwt_utils.read_mbytes ~len:hdrlen fd buf >>= fun () ->
let len = EndianBigstring.BigEndian.get_uint16 buf 0 in
(* TODO timeout read ??? *)
Lwt_utils.read_mbytes ~len fd buf >>= fun () ->
(* TODO conditionnaly decrypt payload... ?? *)
match Data_encoding.Binary.read msg_encoding buf 0 len with
| None ->
(* TODO track invalid message *)
return Disconnect
| Some (read, _) when read <> len ->
(* TODO track invalid message *)
return Disconnect
| Some (_, msg) ->
Lwt.return msg)
(function
| Unix.Unix_error (_err, _, _) -> return None
| Unix.Unix_error _ -> return Disconnect
| e -> fail e)
(* send a message over a TCP socket *)
let send_msg fd buf packet =
let send_msg fd buf msg =
catch
(fun () ->
match Data_encoding.Binary.write msg_encoding packet buf 4 with
match Data_encoding.Binary.write msg_encoding msg buf hdrlen with
| None -> return_false
| Some len ->
BE.set_int32 buf 0 @@ Int32.of_int (len - 4);
Lwt_utils.write_mbytes fd buf ~len >>= fun () ->
return_true
)
(fun exn -> Lwt.fail exn)
if len > maxlen then
return_false
else begin
EndianBigstring.BigEndian.set_int16 buf 0 (len - hdrlen) ;
(* TODO conditionnaly encrypt payload... ? *)
(* TODO timeout write ??? *)
Lwt_utils.write_mbytes ~len fd buf >>= fun () ->
return true
end)
(function
| Unix.Unix_error _ -> return_false
| e -> fail e)
(* A peer handle, as a record-encoded object, abstract from the
outside world. A hidden Lwt worker is associated to a peer at its
@ -377,17 +368,18 @@ module Make (P: PARAMS) = struct
let cancelation, cancel, on_cancel = canceler () in
(* a cancelable reception *)
let recv buf =
pick [ (recv_msg socket buf >|= function Some p -> p | None -> Disconnect);
pick [ recv_msg socket buf ;
(cancelation () >>= fun () -> return Disconnect) ] in
(* First step: send and receive credentials, makes no difference
whether we're trying to connect to a peer or checking an incoming
connection, both parties must first present themselves. *)
let rec connect buf =
send_msg socket buf (Connect { gid = my_gid ;
port = config.incoming_port ;
versions = P.supported_versions }) >>= fun _ ->
send_msg socket buf
(Connect { gid = my_gid ;
port = config.incoming_port ;
versions = P.supported_versions }) >>= fun _ ->
pick [ (LU.sleep limits.peer_answer_timeout >>= fun () -> return Disconnect) ;
recv buf ] >>= function
recv_msg socket buf ] >>= function
| Connect { gid; port = listening_port; versions } ->
debug "(%a) connection requested from %a @ %a:%d"
pp_gid my_gid pp_gid gid Ipaddr.pp_hum addr port ;
@ -458,7 +450,7 @@ module Make (P: PARAMS) = struct
(* Launch the worker *)
receiver ()
in
let buf = MBytes.create 0x100_000 in
let buf = MBytes.create maxlen in
on_cancel (fun () ->
send_msg socket buf Disconnect >>= fun _ ->
LU.close socket >>= fun _ ->
@ -542,8 +534,7 @@ module Make (P: PARAMS) = struct
(* A good random string so it is probably unique on the network *)
let fresh_gid () =
Bytes.to_string @@ Sodium.Random.Bytes.generate 16
Bytes.to_string @@ Sodium.Random.Bytes.generate gid_length
(* The (fixed size) broadcast frame. *)
let discovery_message_encoding =