diff --git a/src/lib_shell/mempool_worker.ml b/src/lib_shell/mempool_worker.ml index 8711a8844..48f756080 100644 --- a/src/lib_shell/mempool_worker.ml +++ b/src/lib_shell/mempool_worker.ml @@ -51,7 +51,7 @@ module type T = sig val shutdown : t -> unit Lwt.t (** parse a new operation and add it to the mempool context *) - val parse : t -> Operation.t -> operation tzresult + val parse : Operation.t -> operation tzresult (** validate a new operation and add it to the mempool context *) val validate : t -> operation -> result tzresult Lwt.t @@ -218,56 +218,61 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct (Format.pp_print_list Error_monad.pp) errors end - (* operations' cache. used for memoization *) - module Cache = struct + (* parsed operations' cache. used for memoization *) + module ParsedCache = struct - type t = { - operations : result Operation_hash.Table.t ; - parsed_operations : operation tzresult Operation_hash.Table.t ; - } + type t = operation tzresult Operation_hash.Table.t let encoding = - let open Data_encoding in - conv - (fun { operations ; parsed_operations } -> (operations, parsed_operations)) - (fun (operations, parsed_operations) -> { operations ; parsed_operations }) - (obj2 - (req "operations" (Operation_hash.Table.encoding result_encoding)) - (req "parsed_operations" - (Operation_hash.Table.encoding - (Error_monad.result_encoding operation_encoding))) - ) + (Operation_hash.Table.encoding + (Error_monad.result_encoding operation_encoding)) - let create () = - { operations = Operation_hash.Table.create 1000 ; - parsed_operations = Operation_hash.Table.create 1000 - } + let create () : t = + Operation_hash.Table.create 1000 - let add_validated t parsed_op result = - Operation_hash.Table.replace t.operations parsed_op.hash result - - let add_parsed t raw_op parsed_op = + let add t raw_op parsed_op = let hash = Operation.hash raw_op in - Operation_hash.Table.replace t.parsed_operations hash parsed_op + Operation_hash.Table.replace t hash parsed_op - let mem_validated t parsed_op = - Operation_hash.Table.mem t.operations parsed_op.hash - - let mem_parsed t raw_op = + let mem t raw_op = let hash = Operation.hash raw_op in - Operation_hash.Table.mem t.parsed_operations hash + Operation_hash.Table.mem t hash - let find_validated_opt t parsed_op = - Operation_hash.Table.find_opt t.operations parsed_op.hash - - let find_parsed_opt t raw_op = + let find_opt t raw_op = let hash = Operation.hash raw_op in - Operation_hash.Table.find_opt t.parsed_operations hash + Operation_hash.Table.find_opt t hash - let iter_validated f t = - Operation_hash.Table.iter f t.operations + let find_hash_opt t hash = + Operation_hash.Table.find_opt t hash - let to_mempool t = + let rem t hash = + Operation_hash.Table.remove t hash + + end + + (* validated operations' cache. used for memoization *) + module ValidatedCache = struct + + type t = result Operation_hash.Table.t + + let encoding = + Operation_hash.Table.encoding result_encoding + + let create () = Operation_hash.Table.create 1000 + + let add t parsed_op result = + Operation_hash.Table.replace t parsed_op.hash result + + let mem t parsed_op = + Operation_hash.Table.mem t parsed_op.hash + + let find_opt t parsed_op = + Operation_hash.Table.find_opt t parsed_op.hash + + let iter f t = + Operation_hash.Table.iter f t + + let to_mempool t parsed_cache = let empty = { Proto_services.Mempool.applied = [] ; refused = Operation_hash.Map.empty ; @@ -283,7 +288,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct { Proto.shell = op.shell ; protocol_data } in Operation_hash.Table.fold (fun hash result acc -> - match Operation_hash.Table.find_opt t.parsed_operations hash with + match ParsedCache.find_hash_opt parsed_cache hash with (* XXX this invariant should be better enforced *) | None | Some (Error _) -> assert false | Some (Ok op) -> begin @@ -319,11 +324,9 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct } | _ -> acc end) - t.operations empty + t empty - let clear t = - Operation_hash.Table.clear t.operations; - Operation_hash.Table.clear t.parsed_operations + let clear t = Operation_hash.Table.clear t end @@ -341,7 +344,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct (* state of the validator. this is updated at each apply_operation *) mutable validation_state : Proto.validation_state ; - cache : Cache.t ; + cache : ValidatedCache.t ; (* live blocks and operations, initialized at worker launch *) live_blocks : Block_hash.Set.t ; @@ -356,7 +359,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct parameters : parameters ; } - type view = { cache : Cache.t } + type view = { cache : ValidatedCache.t } let view (state : state) _ : view = { cache = state.cache } @@ -365,7 +368,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct conv (fun { cache } -> cache) (fun cache -> { cache }) - Cache.encoding + ValidatedCache.encoding let pp ppf _view = Format.fprintf ppf "lots of operations" @@ -378,6 +381,8 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct type t = Worker.infinite Worker.queue Worker.t + let parsed_cache = ParsedCache.create () + let debug w = Format.kasprintf (fun msg -> Worker.record_event w (Debug msg)) @@ -443,7 +448,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct (*** end prevalidation ***) - let parse_helper (_ : t) raw_op = + let parse_helper raw_op = let hash = Operation.hash raw_op in let size = Data_encoding.Binary.length Operation.encoding raw_op in if size > Proto.max_operation_data_length then @@ -480,10 +485,10 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct (* memoization is done only at on_* level *) let on_validate w parsed_op = let state = Worker.state w in - match Cache.find_validated_opt state.cache parsed_op with + match ValidatedCache.find_opt state.cache parsed_op with | None | Some (Branch_delayed _) -> validate_helper w parsed_op >>= fun result -> - Cache.add_validated state.cache parsed_op result; + ValidatedCache.add state.cache parsed_op result; (* operations are notified only the first time *) notify_helper w result parsed_op.raw ; Lwt.return result @@ -500,9 +505,13 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct Chain.data chain_state >>= fun { current_mempool = _mempool ; live_blocks ; live_operations } -> + (* remove all operations that are already included *) + Operation_hash.Set.iter (fun hash -> + ParsedCache.rem parsed_cache hash + ) live_operations; Lwt.return { validation_state ; - cache = Cache.create () ; + cache = ValidatedCache.create () ; live_blocks ; live_operations ; operation_stream = Lwt_watcher.create_input (); @@ -512,11 +521,11 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct let on_close w = let state = Worker.state w in Lwt_watcher.shutdown_input state.operation_stream; - Cache.iter_validated (fun hash _ -> + ValidatedCache.iter (fun hash _ -> Distributed_db.Operation.clear_or_cancel state.parameters.chain_db hash) state.cache ; - Cache.clear state.cache; + ValidatedCache.clear state.cache; Lwt.return_unit let on_error w r st errs = @@ -557,12 +566,11 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct Worker.push_request_and_wait t (Request.Validate parsed_op) (* atomic parse + memoization *) - let parse t raw_op = - let state = Worker.state t in - begin match Cache.find_parsed_opt state.cache raw_op with + let parse raw_op = + begin match ParsedCache.find_opt parsed_cache raw_op with | None -> - let parsed_op = parse_helper t raw_op in - Cache.add_parsed state.cache raw_op parsed_op; + let parsed_op = parse_helper raw_op in + ParsedCache.add parsed_cache raw_op parsed_op; parsed_op | Some parsed_op -> parsed_op end @@ -577,7 +585,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct (Proto_services.S.Mempool.pending_operations RPC_path.open_root) (fun w () () -> let state = Worker.state w in - RPC_answer.return (Cache.to_mempool state.cache) + RPC_answer.return (ValidatedCache.to_mempool state.cache parsed_cache) ) let monitor_rpc_directory : t RPC_directory.t = diff --git a/src/lib_shell/mempool_worker.mli b/src/lib_shell/mempool_worker.mli index e3def9ea2..be9fdd4a4 100644 --- a/src/lib_shell/mempool_worker.mli +++ b/src/lib_shell/mempool_worker.mli @@ -51,8 +51,8 @@ module type T = sig val create : limits -> Distributed_db.chain_db -> t tzresult Lwt.t val shutdown : t -> unit Lwt.t - (** parse a new operation and add it to the mempool context *) - val parse : t -> Operation.t -> operation tzresult + (** parse a new operation *) + val parse : Operation.t -> operation tzresult (** validate a new operation and add it to the mempool context *) val validate : t -> operation -> result tzresult Lwt.t