diff --git a/src/lib_data_encoding/binary_error.ml b/src/lib_data_encoding/binary_error.ml index 502cd51f3..683c8f534 100644 --- a/src/lib_data_encoding/binary_error.ml +++ b/src/lib_data_encoding/binary_error.ml @@ -16,6 +16,7 @@ type read_error = | Invalid_int of { min : int ; v : int ; max : int } | Invalid_float of { min : float ; v : float ; max : float } | Trailing_zero + | Size_limit_exceeded let pp_read_error ppf = function | Not_enough_data -> @@ -34,6 +35,8 @@ let pp_read_error ppf = function Format.fprintf ppf "Invalid float (%f <= %f <= %f) " min v max | Trailing_zero -> Format.fprintf ppf "Trailing zero in Z" + | Size_limit_exceeded -> + Format.fprintf ppf "Size limit exceeded" exception Read_error of read_error diff --git a/src/lib_data_encoding/binary_error.mli b/src/lib_data_encoding/binary_error.mli index 1d43ee274..dc01588c8 100644 --- a/src/lib_data_encoding/binary_error.mli +++ b/src/lib_data_encoding/binary_error.mli @@ -19,6 +19,7 @@ type read_error = | Invalid_int of { min : int ; v : int ; max : int } | Invalid_float of { min : float ; v : float ; max : float } | Trailing_zero + | Size_limit_exceeded exception Read_error of read_error val pp_read_error: Format.formatter -> read_error -> unit diff --git a/src/lib_data_encoding/binary_length.ml b/src/lib_data_encoding/binary_length.ml index 991da65ea..feecf5fef 100644 --- a/src/lib_data_encoding/binary_length.ml +++ b/src/lib_data_encoding/binary_length.ml @@ -110,6 +110,10 @@ let rec length : type x. x Encoding.t -> x -> int = | Splitted { encoding = e } -> length e value | Dynamic_size e -> Binary_size.int32 + length e value + | Check_size { limit ; encoding = e } -> + let length = length e value in + if length > limit then raise (Write_error Size_limit_exceeded) ; + length | Delayed f -> length (f ()) value let fixed_length e = diff --git a/src/lib_data_encoding/binary_reader.ml b/src/lib_data_encoding/binary_reader.ml index bd2936d98..a17ba53cc 100644 --- a/src/lib_data_encoding/binary_reader.ml +++ b/src/lib_data_encoding/binary_reader.ml @@ -15,19 +15,26 @@ type state = { buffer : MBytes.t ; mutable offset : int ; mutable remaining_bytes : int ; + mutable allowed_bytes : int option ; } +let check_allowed_bytes state size = + match state.allowed_bytes with + | Some len when len < size -> raise Size_limit_exceeded + | Some len -> Some (len - size) + | None -> None + let check_remaining_bytes state size = if state.remaining_bytes < size then raise Not_enough_data ; state.remaining_bytes - size let read_atom size conv state = - let remaining_bytes = check_remaining_bytes state size in - let res = conv state.buffer state.offset in + let offset = state.offset in + state.remaining_bytes <- check_remaining_bytes state size ; + state.allowed_bytes <- check_allowed_bytes state size ; state.offset <- state.offset + size ; - state.remaining_bytes <- remaining_bytes ; - res + conv state.buffer offset (** Reader for all the atomic types. *) module Atom = struct @@ -179,6 +186,7 @@ let rec read_rec : type ret. ret Encoding.t -> state -> ret Some (read_rec e state) | Objs (`Fixed sz, e1, e2) -> ignore (check_remaining_bytes state sz : int) ; + ignore (check_allowed_bytes state sz : int option) ; let left = read_rec e1 state in let right = read_rec e2 state in (left, right) @@ -191,6 +199,7 @@ let rec read_rec : type ret. ret Encoding.t -> state -> ret | Tup e -> read_rec e state | Tups (`Fixed sz, e1, e2) -> ignore (check_remaining_bytes state sz : int) ; + ignore (check_allowed_bytes state sz : int option) ; let left = read_rec e1 state in let right = read_rec e2 state in (left, right) @@ -219,10 +228,31 @@ let rec read_rec : type ret. ret Encoding.t -> state -> ret if sz < 0 then raise (Invalid_size sz) ; let remaining = check_remaining_bytes state sz in state.remaining_bytes <- sz ; + ignore (check_allowed_bytes state sz : int option) ; let v = read_rec e state in if state.remaining_bytes <> 0 then raise Extra_bytes ; state.remaining_bytes <- remaining ; v + | Check_size { limit ; encoding = e } -> + let old_allowed_bytes = state.allowed_bytes in + let limit = + match state.allowed_bytes with + | None -> limit + | Some current_limit -> min current_limit limit in + state.allowed_bytes <- Some limit ; + let v = read_rec e state in + let allowed_bytes = + match old_allowed_bytes with + | None -> None + | Some old_limit -> + let remaining = + match state.allowed_bytes with + | None -> assert false + | Some remaining -> remaining in + let read = limit - remaining in + Some (old_limit - read) in + state.allowed_bytes <- allowed_bytes ; + v | Describe { encoding = e } -> read_rec e state | Def { encoding = e } -> read_rec e state | Splitted { encoding = e } -> read_rec e state @@ -267,7 +297,8 @@ and read_list : type a. a Encoding.t -> state -> a list let read encoding buffer ofs len = let state = - { buffer ; offset = ofs ; remaining_bytes = len } in + { buffer ; offset = ofs ; + remaining_bytes = len ; allowed_bytes = None } in match read_rec encoding state with | exception Read_error _ -> None | v -> Some (state.offset, v) @@ -275,7 +306,8 @@ let read encoding buffer ofs len = let of_bytes_exn encoding buffer = let len = MBytes.length buffer in let state = - { buffer ; offset = 0 ; remaining_bytes = len } in + { buffer ; offset = 0 ; + remaining_bytes = len ; allowed_bytes = None } in let v = read_rec encoding state in if state.offset <> len then raise Extra_bytes ; v diff --git a/src/lib_data_encoding/binary_stream_reader.ml b/src/lib_data_encoding/binary_stream_reader.ml index ef2fced19..9543560c6 100644 --- a/src/lib_data_encoding/binary_stream_reader.ml +++ b/src/lib_data_encoding/binary_stream_reader.ml @@ -22,6 +22,10 @@ type state = { illimited). Reading less bytes should raise [Extra_bytes] and trying to read more bytes should raise [Not_enough_data]. *) + allowed_bytes : int option ; + (** Maximum number of bytes that are allowed to be read from 'stream' + before to fail (None = illimited). *) + total_read : int ; (** Total number of bytes that has been read from [stream] since the beginning. *) @@ -41,6 +45,12 @@ let check_remaining_bytes state size = | Some len -> Some (len - size) | None -> None +let check_allowed_bytes state size = + match state.allowed_bytes with + | Some len when len < size -> raise Size_limit_exceeded + | Some len -> Some (len - size) + | None -> None + (** [read_atom resume size conv state k] reads [size] bytes from [state], pass it to [conv] to be decoded, and finally call the continuation [k] with the decoded value and the updated state. @@ -61,9 +71,10 @@ let check_remaining_bytes state size = let read_atom resume size conv state k = match let remaining_bytes = check_remaining_bytes state size in + let allowed_bytes = check_allowed_bytes state size in let res, stream = Binary_stream.read state.stream size in conv res.buffer res.ofs, - { remaining_bytes ; stream ; + { remaining_bytes ; allowed_bytes ; stream ; total_read = state.total_read + size } with | exception (Read_error error) -> Error error @@ -242,6 +253,7 @@ let rec read_rec k (Some v, state) | Objs (`Fixed sz, e1, e2) -> ignore (check_remaining_bytes state sz : int option) ; + ignore (check_allowed_bytes state sz : int option) ; read_rec e1 state @@ fun (left, state) -> read_rec e2 state @@ fun (right, state) -> k ((left, right), state) @@ -254,6 +266,7 @@ let rec read_rec | Tup e -> read_rec e state k | Tups (`Fixed sz, e1, e2) -> ignore (check_remaining_bytes state sz : int option) ; + ignore (check_allowed_bytes state sz : int option) ; read_rec e1 state @@ fun (left, state) -> read_rec e2 state @@ fun (right, state) -> k ((left, right), state) @@ -288,11 +301,31 @@ let rec read_rec else let remaining = check_remaining_bytes state sz in let state = { state with remaining_bytes = Some sz } in + ignore (check_allowed_bytes state sz : int option) ; read_rec e state @@ fun (v, state) -> if state.remaining_bytes <> Some 0 then Error Extra_bytes else k (v, { state with remaining_bytes = remaining }) + | Check_size { limit ; encoding = e } -> + let old_allowed_bytes = state.allowed_bytes in + let limit = + match state.allowed_bytes with + | None -> limit + | Some current_limit -> min current_limit limit in + let state = { state with allowed_bytes = Some limit } in + read_rec e state @@ fun (v, state) -> + let allowed_bytes = + match old_allowed_bytes with + | None -> None + | Some old_limit -> + let remaining = + match state.allowed_bytes with + | None -> assert false + | Some remaining -> remaining in + let read = limit - remaining in + Some (old_limit - read) in + k (v, { state with allowed_bytes }) | Describe { encoding = e } -> read_rec e state k | Def { encoding = e } -> read_rec e state k | Splitted { encoding = e } -> read_rec e state k @@ -362,6 +395,6 @@ let read_stream ?(init = Binary_stream.empty) encoding = invalid_arg "Data_encoding.Binary.read_stream: variable encoding" | `Dynamic | `Fixed _ -> (* No hardcoded read limit in a stream. *) - let state = { remaining_bytes = None ; + let state = { remaining_bytes = None ; allowed_bytes = None ; stream = init ; total_read = 0 } in read_rec encoding state success diff --git a/src/lib_data_encoding/binary_writer.ml b/src/lib_data_encoding/binary_writer.ml index c06835d8e..f6e23882f 100644 --- a/src/lib_data_encoding/binary_writer.ml +++ b/src/lib_data_encoding/binary_writer.ml @@ -268,6 +268,28 @@ let rec write_rec : type a. a Encoding.t -> state -> a -> unit = MBytes.set_int32 state.buffer (initial_offset - Binary_size.int32) (Int32.of_int size) + | Check_size { limit ; encoding = e } -> begin + (* backup the current limit *) + let old_limit = state.allowed_bytes in + (* install the new limit (only if smaller than the current limit) *) + let limit = + match state.allowed_bytes with + | None -> limit + | Some old_limit -> min old_limit limit in + state.allowed_bytes <- Some limit ; + write_rec e state value ; + (* restore the previous limit (minus the read bytes) *) + match old_limit with + | None -> + state.allowed_bytes <- None + | Some old_limit -> + let remaining = + match state.allowed_bytes with + | None -> assert false + | Some len -> len in + let read = limit - remaining in + state.allowed_bytes <- Some (old_limit - read) + end | Describe { encoding = e } -> write_rec e state value | Def { encoding = e } -> write_rec e state value | Splitted { encoding = e } -> write_rec e state value diff --git a/src/lib_data_encoding/data_encoding.mli b/src/lib_data_encoding/data_encoding.mli index 72ea2b715..45bafd39b 100644 --- a/src/lib_data_encoding/data_encoding.mli +++ b/src/lib_data_encoding/data_encoding.mli @@ -387,6 +387,12 @@ module Encoding: sig Usually used to fix errors from combining two encodings. *) val dynamic_size : 'a encoding -> 'a encoding + (** [check_size size encoding] ensures that the binary encoding + of a value will not be allowed to exceed [size] bytes. The reader and + and the writer fails otherwise. This function do not modify + the JSON encoding. *) + val check_size : int -> 'a encoding -> 'a encoding + (** Recompute the encoding definition each time it is used. Useful for dynamically updating the encoding of values of an extensible type via a global reference (e.g. exceptions). *) @@ -538,6 +544,7 @@ module Binary: sig | Invalid_int of { min : int ; v : int ; max : int } | Invalid_float of { min : float ; v : float ; max : float } | Trailing_zero + | Size_limit_exceeded exception Read_error of read_error val pp_read_error: Format.formatter -> read_error -> unit diff --git a/src/lib_data_encoding/encoding.ml b/src/lib_data_encoding/encoding.ml index ef50854dc..aac2bf7b7 100644 --- a/src/lib_data_encoding/encoding.ml +++ b/src/lib_data_encoding/encoding.ml @@ -107,6 +107,7 @@ type 'a desc = json_encoding : 'a Json_encoding.encoding ; is_obj : bool ; is_tup : bool } -> 'a desc | Dynamic_size : 'a t -> 'a desc + | Check_size : { limit : int ; encoding : 'a t } -> 'a desc | Delayed : (unit -> 'a t) -> 'a desc and _ field = @@ -170,6 +171,7 @@ let rec classify : type a. a t -> Kind.t = fun e -> | Def { encoding } -> classify encoding | Splitted { encoding } -> classify encoding | Dynamic_size _ -> `Dynamic + | Check_size { encoding } -> classify encoding | Delayed f -> classify (f ()) let make ?json_encoding encoding = { encoding ; json_encoding } @@ -200,6 +202,9 @@ end let dynamic_size e = make @@ Dynamic_size e +let check_size limit encoding = + make @@ Check_size { limit ; encoding } + let delayed f = make @@ Delayed f @@ -495,6 +500,7 @@ let rec is_nullable: type t. t encoding -> bool = fun e -> | Def { encoding = e } -> is_nullable e | Splitted { json_encoding } -> Json_encoding.is_nullable json_encoding | Dynamic_size e -> is_nullable e + | Check_size { encoding = e } -> is_nullable e | Delayed _ -> true let option ty = diff --git a/src/lib_data_encoding/encoding.mli b/src/lib_data_encoding/encoding.mli index e431838b3..29a95cb49 100644 --- a/src/lib_data_encoding/encoding.mli +++ b/src/lib_data_encoding/encoding.mli @@ -65,6 +65,7 @@ type 'a desc = json_encoding : 'a Json_encoding.encoding ; is_obj : bool ; is_tup : bool } -> 'a desc | Dynamic_size : 'a t -> 'a desc + | Check_size : { limit : int ; encoding : 'a t } -> 'a desc | Delayed : (unit -> 'a t) -> 'a desc and _ field = @@ -121,6 +122,7 @@ module Variable : sig val list : 'a encoding -> 'a list encoding end val dynamic_size : 'a encoding -> 'a encoding +val check_size : int -> 'a encoding -> 'a encoding val delayed : (unit -> 'a encoding) -> 'a encoding val req : ?title:string -> ?description:string -> diff --git a/src/lib_data_encoding/json.ml b/src/lib_data_encoding/json.ml index b36f118e6..32bd760d8 100644 --- a/src/lib_data_encoding/json.ml +++ b/src/lib_data_encoding/json.ml @@ -199,6 +199,7 @@ let rec json : type a. a Encoding.desc -> a Json_encoding.encoding = | Union (_tag_size, _, cases) -> union (List.map case_json cases) | Splitted { json_encoding } -> json_encoding | Dynamic_size e -> get_json e + | Check_size { encoding } -> get_json encoding | Delayed f -> get_json (f ()) and field_json diff --git a/src/lib_data_encoding/test/read_failure.ml b/src/lib_data_encoding/test/read_failure.ml index e999abbac..bab95be44 100644 --- a/src/lib_data_encoding/test/read_failure.ml +++ b/src/lib_data_encoding/test/read_failure.ml @@ -81,6 +81,22 @@ let stream ?(expected = fun _ -> true) read_encoding bytes () = Binary.pp_read_error error done +let minimal_stream ?(expected = fun _ -> true) expected_read read_encoding bytes () = + let name = "minimal_stream" in + match streamed_read read_encoding bytes with + | Binary.Success _, _ -> + Alcotest.failf "%s failed: expecting exception, got success." name + | Binary.Await _, _ -> + Alcotest.failf "%s failed: not enough data" name + | Binary.Error error, count when expected (Binary.Read_error error) && count = expected_read -> + () + | Binary.Error error, count -> + Alcotest.failf + "@[%s failed: read error after reading %d. @ %a@]" + name count + Binary.pp_read_error error + + let all ?expected name write_encoding read_encoding value = let json_value = Json.construct write_encoding value in let bson_value = Bson.construct write_encoding value in @@ -121,6 +137,36 @@ let all_ranged_float minimum maximum = all (name ^ ".min") float encoding (minimum -. 1.) @ all (name ^ ".max") float encoding (maximum +. 1.) +let test_bounded_string_list = + let expected = function + | Binary_error.Read_error Size_limit_exceeded -> true + | _ -> false in + let test name ~total ~elements v expected_read expected_read' = + let bytes = Binary.to_bytes_exn (Variable.list string) v in + let vbytes = Binary.to_bytes_exn (list string) v in + [ "bounded_string_list." ^ name, `Quick, + binary ~expected (bounded_list ~total ~elements string) bytes ; + "bounded_string_list_stream." ^ name, `Quick, + stream ~expected + (dynamic_size (bounded_list ~total:total ~elements string)) vbytes ; + "bounded_string_list_minimal_stream." ^ name, `Quick, + minimal_stream ~expected expected_read + (dynamic_size (bounded_list ~total:total ~elements string)) vbytes ; + "bounded_string_list_minimal_stream." ^ name, `Quick, + minimal_stream ~expected expected_read' + (check_size (total + 4) + (dynamic_size (Variable.list (check_size elements string)))) vbytes ; + + ] in + test "a" ~total:0 ~elements:0 [""] 4 4 @ + test "b1" ~total:3 ~elements:4 [""] 4 4 @ + test "b2" ~total:4 ~elements:3 [""] 4 4 @ + test "c1" ~total:19 ~elements:4 ["";"";"";"";""] 20 4 @ + test "c2" ~total:20 ~elements:3 ["";"";"";"";""] 4 4 @ + test "d1" ~total:20 ~elements:5 ["";"";"";"";"a"] 24 4 @ + test "d2" ~total:21 ~elements:4 ["";"";"";"";"a"] 24 24 @ + test "e" ~total:30 ~elements:10 ["ab";"c";"def";"gh";"ijk"] 32 4 + let tests = all_ranged_int 100 400 @ all_ranged_int 19000 19253 @ @@ -134,6 +180,7 @@ let tests = all "unknown_case.B" ~expected:missing_case union_enc mini_union_enc (B "2") @ all "unknown_case.E" ~expected:missing_case union_enc mini_union_enc E @ all "enum.missing" ~expected:missing_enum enum_enc mini_enum_enc 4 @ + test_bounded_string_list @ [ "z.truncated", `Quick, binary ~expected:not_enough_data z (MBytes.of_string "\x83") ; "z.trailing_zero", `Quick, diff --git a/src/lib_data_encoding/test/success.ml b/src/lib_data_encoding/test/success.ml index 23d67f1c1..a3c338c58 100644 --- a/src/lib_data_encoding/test/success.ml +++ b/src/lib_data_encoding/test/success.ml @@ -122,6 +122,18 @@ let test_string_enum_boundary () = run_test entries2 ; run_test (("256", 256) :: entries2) +let test_bounded_string_list = + let test name ~total ~elements v = + "bounded_string_list." ^ name, `Quick, + binary Alcotest.(list string) + (bounded_list ~total ~elements string) v in + [ test "a" ~total:0 ~elements:0 [] ; + test "b" ~total:4 ~elements:4 [""] ; + test "c" ~total:20 ~elements:4 ["";"";"";"";""] ; + test "d" ~total:21 ~elements:5 ["";"";"";"";"a"] ; + test "e" ~total:31 ~elements:10 ["ab";"c";"def";"gh";"ijk"] ; + ] + let tests = all "null" Alcotest.pass null () @ all "empty" Alcotest.pass empty () @ @@ -219,5 +231,6 @@ let tests = all "array" Alcotest.(array int) (array int31) [|1;2;3;4;5|] @ all "mu_list.empty" Alcotest.(list int) (mu_list_enc int31) [] @ all "mu_list" Alcotest.(list int) (mu_list_enc int31) [1;2;3;4;5] @ + test_bounded_string_list @ [ "string_enum_boundary", `Quick, test_string_enum_boundary ; ] diff --git a/src/lib_data_encoding/test/types.ml b/src/lib_data_encoding/test/types.ml index 2ff4e01b6..04190a813 100644 --- a/src/lib_data_encoding/test/types.ml +++ b/src/lib_data_encoding/test/types.ml @@ -162,6 +162,9 @@ let mu_list_enc enc = (fun (x, xs) -> x :: xs) ; ] +let bounded_list ~total ~elements enc = + check_size total (Variable.list (check_size elements enc)) + module Alcotest = struct include Alcotest let float = diff --git a/src/lib_data_encoding/test/write_failure.ml b/src/lib_data_encoding/test/write_failure.ml index 80e25481e..3e2c66ec3 100644 --- a/src/lib_data_encoding/test/write_failure.ml +++ b/src/lib_data_encoding/test/write_failure.ml @@ -51,6 +51,23 @@ let all_ranged_float minimum maximum = all (name ^ ".min") encoding (minimum -. 1.) @ all (name ^ ".max") encoding (maximum +. 1.) +let test_bounded_string_list = + let expected = function + | Binary_error.Write_error Size_limit_exceeded -> true + | _ -> false in + let test name ~total ~elements v = + "bounded_string_list." ^ name, `Quick, + binary ~expected (bounded_list ~total ~elements string) v in + [ test "a" ~total:0 ~elements:0 [""] ; + test "b1" ~total:3 ~elements:4 [""] ; + test "b2" ~total:4 ~elements:3 [""] ; + test "c1" ~total:19 ~elements:4 ["";"";"";"";""] ; + test "c2" ~total:20 ~elements:3 ["";"";"";"";""] ; + test "d1" ~total:20 ~elements:5 ["";"";"";"";"a"] ; + test "d2" ~total:21 ~elements:4 ["";"";"";"";"a"] ; + test "e" ~total:30 ~elements:10 ["ab";"c";"def";"gh";"ijk"] ; + ] + let tests = all_ranged_int 100 400 @ all_ranged_int 19000 19254 @ @@ -61,4 +78,5 @@ let tests = all "bytes.fixed" (Fixed.bytes 4) (MBytes.of_string "turlututu") @ all "unknown_case.B" mini_union_enc (B "2") @ all "unknown_case.E" mini_union_enc E @ + test_bounded_string_list @ []