diff --git a/src/node/net/p2p_connection.ml b/src/node/net/p2p_connection.ml index 07fbc762a..8c00b04b5 100644 --- a/src/node/net/p2p_connection.ml +++ b/src/node/net/p2p_connection.ml @@ -32,17 +32,53 @@ type error += Rejected type error += Decoding_error type error += Myself of Id_point.t type error += Not_enough_proof_of_work of Gid.t +type error += Invalid_auth -type cryptobox_data = { - channel_key : Crypto_box.channel_key ; - mutable local_nonce : Crypto_box.nonce ; - mutable remote_nonce : Crypto_box.nonce ; -} +module Crypto = struct -let header_length = 2 -let crypto_overhead = 18 (* FIXME import from Sodium.Box. *) -let max_content_length = - 1 lsl (header_length * 8) - crypto_overhead + let header_length = 2 + let crypto_overhead = 18 (* FIXME import from Sodium.Box. *) + let max_content_length = + 1 lsl (header_length * 8) - crypto_overhead + + type data = { + channel_key : Crypto_box.channel_key ; + mutable local_nonce : Crypto_box.nonce ; + mutable remote_nonce : Crypto_box.nonce ; + } + + let write_chunk fd cryptobox_data buf = + let header_buf = MBytes.create header_length in + let local_nonce = cryptobox_data.local_nonce in + cryptobox_data.local_nonce <- Crypto_box.increment_nonce local_nonce ; + let encrypted_message = + Crypto_box.fast_box cryptobox_data.channel_key buf local_nonce in + let encrypted_len = MBytes.length encrypted_message in + fail_unless + (encrypted_len < max_content_length) + Invalid_message_size >>=? fun () -> + MBytes.set_int16 header_buf 0 encrypted_len ; + P2p_io_scheduler.write fd header_buf >>=? fun () -> + P2p_io_scheduler.write fd encrypted_message >>=? fun () -> + return () + + let read_chunk fd cryptobox_data = + let header_buf = MBytes.create header_length in + P2p_io_scheduler.read_full ~len:header_length fd header_buf >>=? fun () -> + let len = MBytes.get_uint16 header_buf 0 in + let buf = MBytes.create len in + P2p_io_scheduler.read_full ~len fd buf >>=? fun () -> + let remote_nonce = cryptobox_data.remote_nonce in + cryptobox_data.remote_nonce <- Crypto_box.increment_nonce remote_nonce ; + match + Crypto_box.fast_box_open cryptobox_data.channel_key buf remote_nonce + with + | None -> + fail Decipher_error + | Some buf -> + return buf + +end module Connection_message = struct @@ -78,11 +114,12 @@ module Connection_message = struct let encoded_message_len = Data_encoding.Binary.length encoding message in fail_unless - (encoded_message_len < max_content_length) + (encoded_message_len < Crypto.max_content_length) Encoding_error >>=? fun () -> - let len = header_length + encoded_message_len in + let len = Crypto.header_length + encoded_message_len in let buf = MBytes.create len in - match Data_encoding.Binary.write encoding message buf header_length with + match Data_encoding.Binary.write + encoding message buf Crypto.header_length with | None -> fail Encoding_error | Some last -> @@ -91,8 +128,9 @@ module Connection_message = struct P2p_io_scheduler.write fd buf let read fd = - let header_buf = MBytes.create header_length in - P2p_io_scheduler.read_full ~len:header_length fd header_buf >>=? fun () -> + let header_buf = MBytes.create Crypto.header_length in + P2p_io_scheduler.read_full + ~len:Crypto.header_length fd header_buf >>=? fun () -> let len = MBytes.get_uint16 header_buf 0 in let buf = MBytes.create len in P2p_io_scheduler.read_full ~len fd buf >>=? fun () -> @@ -109,29 +147,25 @@ end module Ack = struct - type t = bool + type t = Ack | Nack let ack = MBytes.of_string "\255" let nack = MBytes.of_string "\000" - let write fd b = - match b with - | true -> - P2p_io_scheduler.write fd ack - | false -> - P2p_io_scheduler.write fd nack + let write cryptobox_data fd b = + Crypto.write_chunk cryptobox_data fd + (match b with Ack -> ack | Nack -> nack) - let read fd = - let buf = MBytes.create 1 in - P2p_io_scheduler.read_full fd buf >>=? fun () -> + let read fd cryptobox_data = + Crypto.read_chunk fd cryptobox_data >>=? fun buf -> return (buf <> nack) end type authenticated_fd = - P2p_io_scheduler.connection * Connection_info.t * cryptobox_data + P2p_io_scheduler.connection * Connection_info.t * Crypto.data -let kick (fd, _ , _) = - Ack.write fd false >>= fun _ -> +let kick (fd, _ , cryptobox_data) = + Ack.write fd cryptobox_data Nack >>= fun _ -> P2p_io_scheduler.close fd >>= fun _ -> Lwt.return_unit @@ -168,14 +202,14 @@ let authenticate { Connection_info.gid = remote_gid ; versions = msg.versions ; incoming ; id_point ; remote_socket_port ;} in let cryptobox_data = - { channel_key ; local_nonce ; + { Crypto.channel_key ; local_nonce ; remote_nonce = msg.message_nonce } in return (info, (fd, info, cryptobox_data)) type connection = { info : Connection_info.t ; fd : P2p_io_scheduler.connection ; - cryptobox_data : cryptobox_data ; + cryptobox_data : Crypto.data ; } module Reader = struct @@ -188,29 +222,13 @@ module Reader = struct mutable worker: unit Lwt.t ; } - let read_chunk { fd ; cryptobox_data } = - let header_buf = MBytes.create header_length in - P2p_io_scheduler.read_full ~len:header_length fd header_buf >>=? fun () -> - let len = MBytes.get_uint16 header_buf 0 in - let buf = MBytes.create len in - P2p_io_scheduler.read_full ~len fd buf >>=? fun () -> - let remote_nonce = cryptobox_data.remote_nonce in - cryptobox_data.remote_nonce <- Crypto_box.increment_nonce remote_nonce ; - match - Crypto_box.fast_box_open cryptobox_data.channel_key buf remote_nonce - with - | None -> - fail Decipher_error - | Some buf -> - return buf - let rec read_message st buf = return (Data_encoding.Binary.of_bytes st.encoding buf) let rec worker_loop st = Lwt_unix.yield () >>= fun () -> Lwt_utils.protect ~canceler:st.canceler begin fun () -> - read_chunk st.conn >>=? fun buf -> + Crypto.read_chunk st.conn.fd st.conn.cryptobox_data >>=? fun buf -> read_message st buf end >>= function | Ok None -> @@ -258,21 +276,6 @@ module Writer = struct mutable worker: unit Lwt.t ; } - let write_chunk { cryptobox_data ; fd } buf = - let header_buf = MBytes.create header_length in - let local_nonce = cryptobox_data.local_nonce in - cryptobox_data.local_nonce <- Crypto_box.increment_nonce local_nonce ; - let encrypted_message = - Crypto_box.fast_box cryptobox_data.channel_key buf local_nonce in - let encrypted_len = MBytes.length encrypted_message in - fail_unless - (encrypted_len < max_content_length) - Invalid_message_size >>=? fun () -> - MBytes.set_int16 header_buf 0 encrypted_len ; - P2p_io_scheduler.write fd header_buf >>=? fun () -> - P2p_io_scheduler.write fd encrypted_message >>=? fun () -> - return () - let encode_message st msg = try return (Data_encoding.Binary.to_bytes st.encoding msg) with _ -> fail Encoding_error @@ -282,7 +285,7 @@ module Writer = struct Lwt_utils.protect ~canceler:st.canceler begin fun () -> Lwt_pipe.pop st.messages >>= fun (msg, wakener) -> encode_message st msg >>=? fun buf -> - write_chunk st.conn buf >>= fun res -> + Crypto.write_chunk st.conn.fd st.conn.cryptobox_data buf >>= fun res -> iter_option wakener ~f:(fun u -> Lwt.wakeup_later u res) ; Lwt.return res end >>= function @@ -332,11 +335,14 @@ let accept ?incoming_message_queue_size ?outgoing_message_queue_size (fd, info, cryptobox_data) encoding = Lwt_utils.protect begin fun () -> - Ack.write fd true >>=? fun () -> - Ack.read fd + Ack.write fd cryptobox_data Ack >>=? fun () -> + Ack.read fd cryptobox_data end ~on_error:begin fun err -> P2p_io_scheduler.close fd >>= fun _ -> - Lwt.return (Error err) + match err with + | [ P2p_io_scheduler.Connection_closed ] -> fail Rejected + | [ Decipher_error ] -> fail Invalid_auth + | err -> Lwt.return (Error err) end >>=? fun accepted -> fail_unless accepted Rejected >>=? fun () -> let canceler = Canceler.create () in diff --git a/src/node/net/p2p_connection.mli b/src/node/net/p2p_connection.mli index 8d335a68c..890cc6c34 100644 --- a/src/node/net/p2p_connection.mli +++ b/src/node/net/p2p_connection.mli @@ -26,6 +26,7 @@ type error += Decoding_error type error += Rejected type error += Myself of Id_point.t type error += Not_enough_proof_of_work of Gid.t +type error += Invalid_auth type authenticated_fd (** Type of a connection that successfully passed the authentication diff --git a/test/test_p2p_connection.ml b/test/test_p2p_connection.ml index e0d84cbc8..2dda0293a 100644 --- a/test/test_p2p_connection.ml +++ b/test/test_p2p_connection.ml @@ -83,11 +83,17 @@ let simple_msg = let is_rejected = function | Error [P2p_connection.Rejected] -> true - | Ok _ | Error _ -> false + | Ok _ -> false + | Error err -> + log_notice "Error: %a" pp_print_error err ; + false let is_connection_closed = function | Error [P2p_io_scheduler.Connection_closed] -> true - | Ok _ | Error _ -> false + | Ok _ -> false + | Error err -> + log_notice "Error: %a" pp_print_error err ; + false let bytes_encoding = Data_encoding.Variable.bytes